# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reader_ops."""


import os.path

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from syntaxnet import dictionary_pb2
from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet.ops import gen_parser_ops

FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
  FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
  FLAGS.test_tmpdir = tf.test.get_temp_dir()
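# The test_srcdir/test_tmpdir flags are normally supplied by the test runner;
# the fallbacks above let this file be executed directly with plain `python`.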

class ParsingReaderOpsTest(test_util.TensorFlowTestCase):

  def setUp(self):
    # Creates a task context with the correct testing paths.
    initial_task_context = os.path.join(FLAGS.test_srcdir,
                                        'syntaxnet/'
                                        'testdata/context.pbtxt')
    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
    with open(initial_task_context, 'r') as fin:
      with open(self._task_context, 'w') as fout:
        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
                   .replace('OUTPATH', FLAGS.test_tmpdir))

    # Creates necessary term maps.
    with self.test_session() as sess:
      gen_parser_ops.lexicon_builder(task_context=self._task_context,
                                     corpus_name='training-corpus').run()
      self._num_features, self._num_feature_ids, _, self._num_actions = (
          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
                                               arg_prefix='brain_parser')))
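
  # For orientation: feature_size reports the configured feature space. Going
  # by how its outputs are used in this file (an inference from this test, not
  # authoritative documentation of the op):
  #
  #   self._num_features     # number of features per embedding channel
  #   self._num_feature_ids  # id-domain (vocabulary) size per channel
  #   _                      # unused here; likely the default embedding dims
  #   self._num_actions      # number of transition-system actions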

  def GetMaxId(self, sparse_features):
    """Returns the largest feature id in a batch of serialized SparseFeatures."""
    max_id = 0
    for x in sparse_features:
      for y in x:
        f = sparse_pb2.SparseFeatures()
        f.ParseFromString(y)
        for i in f.id:
          max_id = max(i, max_id)
    return max_id
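
  # Each feature channel emitted by the reader is a batch of serialized
  # SparseFeatures protos, which is what GetMaxId unpacks. Decoding one entry
  # by hand would look like this (a sketch; assumes the `tf_words` batch
  # produced in testParsingReaderOp below):
  #
  #   f = sparse_pb2.SparseFeatures()
  #   f.ParseFromString(tf_words[0][0])
  #   ids = list(f.id)  # the feature ids GetMaxId scans over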

  def testParsingReaderOp(self):
    # Runs the reader over the test input for two epochs.
    num_steps_a = 0
    num_actions = 0
    num_word_ids = 0
    num_tag_ids = 0
    num_label_ids = 0
    batch_size = 10
    with self.test_session() as sess:
      (words, tags, labels), epochs, gold_actions = (
          gen_parser_ops.gold_parse_reader(self._task_context,
                                           3,
                                           batch_size,
                                           corpus_name='training-corpus'))
      while True:
        tf_gold_actions, tf_epochs, tf_words, tf_tags, tf_labels = (
            sess.run([gold_actions, epochs, words, tags, labels]))
        num_steps_a += 1
        num_actions = max(num_actions, max(tf_gold_actions) + 1)
        num_word_ids = max(num_word_ids, self.GetMaxId(tf_words) + 1)
        num_tag_ids = max(num_tag_ids, self.GetMaxId(tf_tags) + 1)
        num_label_ids = max(num_label_ids, self.GetMaxId(tf_labels) + 1)
        self.assertIn(tf_epochs, [0, 1, 2])
        if tf_epochs > 1:
          break

    # Runs the reader again, this time with a lot of added graph nodes.
    num_steps_b = 0
    with self.test_session() as sess:
      num_features = [6, 6, 4]
      num_feature_ids = [num_word_ids, num_tag_ids, num_label_ids]
      embedding_sizes = [8, 8, 8]
      hidden_layer_sizes = [32, 32]
      # Here we aim to test the iteration of the reader op in a complex
      # network, not the GraphBuilder.
      parser = graph_builder.GreedyParser(
          num_actions, num_features, num_feature_ids, embedding_sizes,
          hidden_layer_sizes)
      parser.AddTraining(self._task_context,
                         batch_size,
                         corpus_name='training-corpus')
      sess.run(parser.inits.values())
      while True:
        tf_epochs, tf_cost, _ = sess.run(
            [parser.training['epochs'], parser.training['cost'],
             parser.training['train_op']])
        num_steps_b += 1
        self.assertGreaterEqual(tf_cost, 0)
        self.assertIn(tf_epochs, [0, 1, 2])
        if tf_epochs > 1:
          break

    # Assert that the two runs made the exact same number of steps.
    logging.info('Number of steps in the two runs: %d, %d',
                 num_steps_a, num_steps_b)
    self.assertEqual(num_steps_a, num_steps_b)
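
  # For reference, the gold_parse_reader endpoints as used above (an
  # inference from this test, not authoritative documentation):
  #
  #   (words, tags, labels), epochs, gold_actions = ...
  #   # words/tags/labels: batches of serialized SparseFeatures, one tensor
  #   #                    per feature channel
  #   # epochs: scalar count of completed passes over the corpus
  #   # gold_actions: oracle transition actions for the current batch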

  def testParsingReaderOpWhileLoop(self):
    feature_size = 3
    batch_size = 5

    def ParserEndpoints():
      return gen_parser_ops.gold_parse_reader(self._task_context,
                                              feature_size,
                                              batch_size,
                                              corpus_name='training-corpus')

    with self.test_session() as sess:
      # The 'condition' and 'body' functions expect as many arguments as there
      # are loop variables. 'condition' depends on the 'epoch' loop variable
      # only, so we disregard the remaining unused function arguments. 'body'
      # returns a list of updated loop variables.
      def Condition(epoch, *unused_args):
        return tf.less(epoch, 2)

      def Body(epoch, num_actions, *feature_args):
        # By adding one of the outputs of the reader op ('epoch') as a control
        # dependency to the reader op we force the repeated evaluation of the
        # reader op.
        with epoch.graph.control_dependencies([epoch]):
          features, epoch, gold_actions = ParserEndpoints()
        num_actions = tf.maximum(num_actions,
                                 tf.reduce_max(gold_actions, [0], False) + 1)
        feature_ids = []
        for i in range(len(feature_args)):
          feature_ids.append(features[i])
        return [epoch, num_actions] + feature_ids

      epoch = ParserEndpoints()[-2]
      num_actions = tf.constant(0)
      loop_vars = [epoch, num_actions]
      res = sess.run(
          tf.while_loop(Condition, Body, loop_vars,
                        shape_invariants=[tf.TensorShape(None)] * 2,
                        parallel_iterations=1))
      logging.info('Result: %s', res)
      self.assertEqual(res[0], 2)
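
  # tf.while_loop in miniature, for readers new to the API (a sketch, not
  # part of the test; `i` plays the role of the 'epoch' loop variable):
  #
  #   i = tf.constant(0)
  #   r = tf.while_loop(lambda i: tf.less(i, 2), lambda i: i + 1, [i])
  #   # sess.run(r) == 2, just as res[0] == 2 above.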

  def _token_embedding(self, token, embedding):
    """Serializes a TokenEmbedding proto for the given token and vector."""
    e = dictionary_pb2.TokenEmbedding()
    e.token = token
    e.vector.values.extend(embedding)
    return e.SerializeToString()

  def testWordEmbeddingInitializer(self):
    # Provide embeddings for the first three words in the word map.
    records_path = os.path.join(FLAGS.test_tmpdir, 'records1')
    writer = tf.python_io.TFRecordWriter(records_path)
    writer.write(self._token_embedding('.', [1, 2]))
    writer.write(self._token_embedding(',', [3, 4]))
    writer.write(self._token_embedding('the', [5, 6]))
    del writer  # Closes the writer and flushes the records to disk.

    with self.test_session():
      embeddings = gen_parser_ops.word_embedding_initializer(
          vectors=records_path,
          task_context=self._task_context).eval()
    self.assertAllClose(
        np.array([[1. / (1 + 4) ** .5, 2. / (1 + 4) ** .5],
                  [3. / (9 + 16) ** .5, 4. / (9 + 16) ** .5],
                  [5. / (25 + 36) ** .5, 6. / (25 + 36) ** .5]]),
        embeddings[:3,])
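
  # The initializer evidently L2-normalizes each embedding row: the expected
  # matrix above is each input vector divided by its Euclidean norm, e.g.
  #
  #   [1, 2] / sqrt(1**2 + 2**2) == [1 / 5**.5, 2 / 5**.5]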

  def testWordEmbeddingInitializerRepeatability(self):
    records_path = os.path.join(FLAGS.test_tmpdir, 'records2')
    writer = tf.python_io.TFRecordWriter(records_path)
    writer.write(self._token_embedding('.', [1, 2, 3]))  # 3 dims
    del writer

    # As long as there is one non-zero seed, the result should be repeatable.
    for seed1, seed2 in [(0, 1), (1, 0), (123, 456)]:
      with tf.Graph().as_default(), self.test_session():
        embeddings1 = gen_parser_ops.word_embedding_initializer(
            vectors=records_path,
            task_context=self._task_context,
            seed=seed1,
            seed2=seed2)
        embeddings2 = gen_parser_ops.word_embedding_initializer(
            vectors=records_path,
            task_context=self._task_context,
            seed=seed1,
            seed2=seed2)

        # The number of terms is based on the word map, which may change if
        # the test corpus is updated. Just assert that there are some terms.
        self.assertGreater(tf.shape(embeddings1)[0].eval(), 0)
        self.assertGreater(tf.shape(embeddings2)[0].eval(), 0)
        self.assertEqual(tf.shape(embeddings1)[1].eval(), 3)
        self.assertEqual(tf.shape(embeddings2)[1].eval(), 3)
        self.assertAllEqual(embeddings1.eval(), embeddings2.eval())
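
  # Note: seed/seed2 mirror TensorFlow's usual two-level seeding for random
  # ops; if both were zero the kernel would pick a nondeterministic seed,
  # which is why every (seed1, seed2) pair in the loop above has at least one
  # non-zero entry.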


if __name__ == '__main__':
  googletest.main()