# beam_reader_ops_test.py
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Tests for beam_reader_ops."""


import os.path
import time

import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops

FLAGS = tf.app.flags.FLAGS
# Some test runners do not predefine these flags; fall back to sane
# defaults so the test can locate its data and write temp output.
if not hasattr(FLAGS, 'test_srcdir'):
  FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
  FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
  """Exercises the beam reader ops through the structured graph builder.

  NOTE(review): indentation was reconstructed from a whitespace-mangled
  source; block structure was inferred from syntax and should be confirmed
  against the original file.
  """

  def setUp(self):
    """Copies the test task context into the temp dir and builds term maps."""
    # Creates a task context with the correct testing paths.
    initial_task_context = os.path.join(
        FLAGS.test_srcdir,
        'syntaxnet/'
        'testdata/context.pbtxt')
    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
    with open(initial_task_context, 'r') as fin:
      with open(self._task_context, 'w') as fout:
        # The checked-in context uses SRCDIR/OUTPATH placeholders; rewrite
        # them to the actual test source and temp directories.
        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
                   .replace('OUTPATH', FLAGS.test_tmpdir))

    # Creates necessary term maps.
    with self.test_session() as sess:
      gen_parser_ops.lexicon_builder(task_context=self._task_context,
                                     corpus_name='training-corpus').run()
      # Cache feature/action dimensions for later graph construction.
      self._num_features, self._num_feature_ids, _, self._num_actions = (
          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
                                               arg_prefix='brain_parser')))

  def MakeGraph(self,
                max_steps=10,
                beam_size=2,
                batch_size=1,
                **kwargs):
    """Constructs a structured learning graph."""
    assert max_steps > 0, 'Empty network not supported.'
    logging.info('MakeGraph + %s', kwargs)
    with self.test_session(graph=tf.Graph()) as sess:
      feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
          gen_parser_ops.feature_size(task_context=self._task_context))
      # Deliberately overrides the embedding dims read above: keeps the
      # test model tiny.
      embedding_dims = [8, 8, 8]
      hidden_layer_sizes = []  # No hidden layers: effectively a linear model.
      learning_rate = 0.01
      builder = structured_graph_builder.StructuredGraphBuilder(
          num_actions,
          feature_sizes,
          domain_sizes,
          embedding_dims,
          hidden_layer_sizes,
          seed=1,
          max_steps=max_steps,
          beam_size=beam_size,
          gate_gradients=True,
          use_locking=True,
          use_averaging=False,
          check_parameters=False,
          **kwargs)
      builder.AddTraining(self._task_context,
                          batch_size,
                          learning_rate=learning_rate,
                          decay_steps=1000,
                          momentum=0.9,
                          corpus_name='training-corpus')
      builder.AddEvaluation(self._task_context,
                            batch_size,
                            evaluation_max_steps=25,
                            corpus_name=None)
      # Single op that initializes everything the builder created.
      builder.training['inits'] = tf.group(*builder.inits.values(),
                                           name='inits')
      return builder

  def Train(self, **kwargs):
    """Runs 10 training steps, then 10 eval-only steps, logging averages."""
    with self.test_session(graph=tf.Graph()) as sess:
      max_steps = 3
      batch_size = 3
      beam_size = 3
      builder = (
          self.MakeGraph(
              max_steps=max_steps, beam_size=beam_size,
              batch_size=batch_size, **kwargs))
      logging.info('params: %s', builder.params.keys())
      logging.info('variables: %s', builder.variables.keys())
      t = builder.training
      sess.run(t['inits'])
      costs = []
      gold_slots = []
      alive_steps_vector = []
      every_n = 5  # Log a rolling average every 5 steps.
      walltime = time.time()
      for step in range(10):
        if step > 0 and step % every_n == 0:
          new_walltime = time.time()
          logging.info(
              'Step: %d <cost>: %f <gold_slot>: %f <alive_steps>: %f <iter '
              'time>: %f ms',
              step, sum(costs[-every_n:]) / float(every_n),
              sum(gold_slots[-every_n:]) / float(every_n),
              sum(alive_steps_vector[-every_n:]) / float(every_n),
              1000 * (new_walltime - walltime) / float(every_n))
          walltime = new_walltime
        cost, gold_slot, alive_steps, _ = sess.run(
            [t['cost'], t['gold_slot'], t['alive_steps'], t['train_op']])
        costs.append(cost)
        gold_slots.append(gold_slot.mean())
        alive_steps_vector.append(alive_steps.mean())

      # NOTE(review): reaches into builder privates (_only_train,
      # _use_averaging); acceptable in a test but brittle against builder
      # refactors.
      if builder._only_train:
        trainable_param_names = [
            k for k in builder.params if k in builder._only_train]
      else:
        trainable_param_names = builder.params.keys()
      if builder._use_averaging:
        # Swap the moving-average values into the live parameters before eval.
        for v in trainable_param_names:
          avg = builder.variables['%s_avg_var' % v].eval()
          tf.assign(builder.params[v], avg).eval()

      # Reset for pseudo eval.
      costs = []
      gold_slots = []
      alive_stepss = []
      for step in range(10):
        # Same fetches as training minus train_op, so no updates happen.
        cost, gold_slot, alive_steps = sess.run(
            [t['cost'], t['gold_slot'], t['alive_steps']])
        costs.append(cost)
        gold_slots.append(gold_slot.mean())
        alive_stepss.append(alive_steps.mean())
      logging.info(
          'Pseudo eval: <cost>: %f <gold_slot>: %f <alive_steps>: %f',
          sum(costs[-every_n:]) / float(every_n),
          sum(gold_slots[-every_n:]) / float(every_n),
          sum(alive_stepss[-every_n:]) / float(every_n))

  def PathScores(self, iterations, beam_size, max_steps, batch_size):
    """Collects network-computed and beam-computed path scores.

    Returns:
      (all_path_scores, beam_path_scores): per-iteration lists of the two
      score tensors, for the caller to compare.
    """
    with self.test_session(graph=tf.Graph()) as sess:
      t = self.MakeGraph(beam_size=beam_size, max_steps=max_steps,
                         batch_size=batch_size).training
      sess.run(t['inits'])
      all_path_scores = []
      beam_path_scores = []
      for i in range(iterations):
        logging.info('run %d', i)
        tensors = (
            sess.run(
                [t['alive_steps'], t['concat_scores'],
                 t['all_path_scores'], t['beam_path_scores'],
                 t['indices'], t['path_ids']]))
        logging.info('alive for %s, all_path_scores and beam_path_scores, '
                     'indices and path_ids:'
                     '\n%s\n%s\n%s\n%s',
                     tensors[0], tensors[2], tensors[3], tensors[4],
                     tensors[5])
        logging.info('diff:\n%s', tensors[2] - tensors[3])
        all_path_scores.append(tensors[2])
        beam_path_scores.append(tensors[3])
      return all_path_scores, beam_path_scores

  def testParseUntilNotAlive(self):
    """Ensures that the 'alive' condition works in the Cond ops."""
    with self.test_session(graph=tf.Graph()) as sess:
      t = self.MakeGraph(batch_size=3, beam_size=2, max_steps=5).training
      sess.run(t['inits'])
      for i in range(5):
        logging.info('run %d', i)
        tf_alive = t['alive'].eval()
        # After a full evaluation pass, no beam should still be alive.
        self.assertFalse(any(tf_alive))

  def testParseMomentum(self):
    """Ensures that Momentum training can be done using the gradients."""
    self.Train()
    self.Train(model_cost='perceptron_loss')
    self.Train(model_cost='perceptron_loss',
               only_train='softmax_weight,softmax_bias', softmax_init=0)
    self.Train(only_train='softmax_weight,softmax_bias', softmax_init=0)

  def testPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=5, batch_size=1)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)

  def testBatchPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=5, batch_size=22)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)

  def testBatchOneStepPathScoresAgree(self):
    """Ensures that path scores computed in the beam are same in the net."""
    all_path_scores, beam_path_scores = self.PathScores(
        iterations=1, beam_size=130, max_steps=1, batch_size=22)
    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
if __name__ == '__main__':
  # Discovers and runs all test* methods of the TestCase above.
  googletest.main()