beam_reader_ops_test.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for beam_reader_ops."""

import os.path
import time

import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops

FLAGS = tf.app.flags.FLAGS

# The Bazel test harness normally defines these flags; provide fallbacks so
# the test can also run outside the harness (e.g. invoked directly).
if not hasattr(FLAGS, 'test_srcdir'):
  FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
  FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
  """Tests beam reader ops through the structured (beam) graph builder.

  Each test builds a small training/evaluation graph over the bundled test
  corpus and checks beam bookkeeping (the 'alive' condition, path scores)
  against the network's own computations.
  """

  def setUp(self):
    """Prepares a per-test task context and builds the lexicon term maps."""
    # Creates a task context with the correct testing paths.
    initial_task_context = os.path.join(FLAGS.test_srcdir,
                                        'syntaxnet/'
                                        'testdata/context.pbtxt')
    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
    # Substitute the SRCDIR/OUTPATH placeholders in the checked-in context so
    # its resources resolve under the test source and temp directories.
    with open(initial_task_context, 'r') as fin:
      with open(self._task_context, 'w') as fout:
        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
                   .replace('OUTPATH', FLAGS.test_tmpdir))

    # Creates necessary term maps.  The lexicon must exist before
    # feature_size below can be evaluated.
    with self.test_session() as sess:
      gen_parser_ops.lexicon_builder(task_context=self._task_context,
                                     corpus_name='training-corpus').run()
      self._num_features, self._num_feature_ids, _, self._num_actions = (
          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
                                               arg_prefix='brain_parser')))

  def MakeGraph(self,
                max_steps=10,
                beam_size=2,
                batch_size=1,
                **kwargs):
    """Constructs a structured learning graph."""
    assert max_steps > 0, 'Empty network not supported.'

    logging.info('MakeGraph + %s', kwargs)

    # Read the feature geometry in a throwaway graph/session; only the
    # fetched numpy values are needed to parameterize the builder below.
    with self.test_session(graph=tf.Graph()) as sess:
      feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
          gen_parser_ops.feature_size(task_context=self._task_context))

    # Override the reported embedding dims with small fixed values, and use
    # no hidden layers.
    embedding_dims = [8, 8, 8]
    hidden_layer_sizes = []
    learning_rate = 0.01
    builder = structured_graph_builder.StructuredGraphBuilder(
        num_actions,
        feature_sizes,
        domain_sizes,
        embedding_dims,
        hidden_layer_sizes,
        seed=1,
        max_steps=max_steps,
        beam_size=beam_size,
        gate_gradients=True,
        use_locking=True,
        use_averaging=False,
        check_parameters=False,
        **kwargs)
    builder.AddTraining(self._task_context,
                        batch_size,
                        learning_rate=learning_rate,
                        decay_steps=1000,
                        momentum=0.9,
                        corpus_name='training-corpus')
    builder.AddEvaluation(self._task_context,
                          batch_size,
                          evaluation_max_steps=25,
                          corpus_name=None)

    # Collapse all variable initializers into one convenient op.
    builder.training['inits'] = tf.group(*builder.inits.values(), name='inits')
    return builder

  def Train(self, **kwargs):
    """Runs 10 training steps, then 10 pseudo-eval steps, logging averages."""
    with self.test_session(graph=tf.Graph()) as sess:
      max_steps = 3
      batch_size = 3
      beam_size = 3
      builder = (
          self.MakeGraph(
              max_steps=max_steps, beam_size=beam_size,
              batch_size=batch_size, **kwargs))

      logging.info('params: %s', builder.params.keys())
      logging.info('variables: %s', builder.variables.keys())

      t = builder.training
      sess.run(t['inits'])

      costs = []
      gold_slots = []
      alive_steps_vector = []
      every_n = 5  # logging period, in steps
      walltime = time.time()
      for step in range(10):
        # Log running averages over the last every_n steps.
        if step > 0 and step % every_n == 0:
          new_walltime = time.time()
          logging.info(
              'Step: %d <cost>: %f <gold_slot>: %f <alive_steps>: %f <iter '
              'time>: %f ms',
              step, sum(costs[-every_n:]) / float(every_n),
              sum(gold_slots[-every_n:]) / float(every_n),
              sum(alive_steps_vector[-every_n:]) / float(every_n),
              1000 * (new_walltime - walltime) / float(every_n))
          walltime = new_walltime
        cost, gold_slot, alive_steps, _ = sess.run(
            [t['cost'], t['gold_slot'], t['alive_steps'], t['train_op']])
        costs.append(cost)
        gold_slots.append(gold_slot.mean())
        alive_steps_vector.append(alive_steps.mean())

      # NOTE(review): reaches into the builder's private _only_train and
      # _use_averaging fields to decide which parameters get their averaged
      # ('%s_avg_var') values copied in before the pseudo eval below.
      if builder._only_train:
        trainable_param_names = [
            k for k in builder.params if k in builder._only_train]
      else:
        trainable_param_names = builder.params.keys()
      if builder._use_averaging:
        for v in trainable_param_names:
          avg = builder.variables['%s_avg_var' % v].eval()
          tf.assign(builder.params[v], avg).eval()

      # Reset for pseudo eval.
      costs = []
      gold_slots = []
      alive_stepss = []
      # Run the same fetches without train_op, so no parameters are updated.
      for step in range(10):
        cost, gold_slot, alive_steps = sess.run(
            [t['cost'], t['gold_slot'], t['alive_steps']])
        costs.append(cost)
        gold_slots.append(gold_slot.mean())
        alive_stepss.append(alive_steps.mean())

      logging.info(
          'Pseudo eval: <cost>: %f <gold_slot>: %f <alive_steps>: %f',
          sum(costs[-every_n:]) / float(every_n),
          sum(gold_slots[-every_n:]) / float(every_n),
          sum(alive_stepss[-every_n:]) / float(every_n))

  def PathScores(self, iterations, beam_size, max_steps, batch_size):
    """Fetches network-computed and beam-computed path scores.

    Returns:
      A pair (all_path_scores, beam_path_scores) of lists with one fetched
      value per iteration, for comparison by the callers.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      t = self.MakeGraph(beam_size=beam_size, max_steps=max_steps,
                         batch_size=batch_size).training
      sess.run(t['inits'])
      all_path_scores = []
      beam_path_scores = []
      for i in range(iterations):
        logging.info('run %d', i)
        # tensors[1] (concat_scores) is fetched but only the others are used.
        tensors = (
            sess.run(
                [t['alive_steps'], t['concat_scores'],
                 t['all_path_scores'], t['beam_path_scores'],
                 t['indices'], t['path_ids']]))
        logging.info('alive for %s, all_path_scores and beam_path_scores, '
                     'indices and path_ids:'
                     '\n%s\n%s\n%s\n%s',
                     tensors[0], tensors[2], tensors[3], tensors[4], tensors[5])
        logging.info('diff:\n%s', tensors[2] - tensors[3])
        all_path_scores.append(tensors[2])
        beam_path_scores.append(tensors[3])
      return all_path_scores, beam_path_scores

  def testParseUntilNotAlive(self):
    """Ensures that the 'alive' condition works in the Cond ops."""
    with self.test_session(graph=tf.Graph()) as sess:
      t = self.MakeGraph(batch_size=3, beam_size=2, max_steps=5).training
      sess.run(t['inits'])
      for i in range(5):
        logging.info('run %d', i)
        tf_alive = t['alive'].eval()
        self.assertFalse(any(tf_alive))

  def testParseMomentum(self):
    """Ensures that Momentum training can be done using the gradients."""
    self.Train()
    self.Train(model_cost='perceptron_loss')
    self.Train(model_cost='perceptron_loss',
               only_train='softmax_weight,softmax_bias', softmax_init=0)
    self.Train(only_train='softmax_weight,softmax_bias', softmax_init=0)

  def testPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=5, batch_size=1)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)

  def testBatchPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=5, batch_size=22)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)

  def testBatchOneStepPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=1, batch_size=22)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
# Standard test entry point.
if __name__ == '__main__':
  googletest.main()