graph_builder_test.py

  1. """Tests for graph_builder."""
  2. import collections
  3. import os.path
  4. import numpy as np
  5. import tensorflow as tf
  6. from google.protobuf import text_format
  7. from dragnn.protos import spec_pb2
  8. from dragnn.protos import trace_pb2
  9. from dragnn.python import dragnn_ops
  10. from dragnn.python import graph_builder
  11. from syntaxnet import sentence_pb2
  12. from tensorflow.python.framework import test_util
  13. from tensorflow.python.platform import googletest
  14. from tensorflow.python.platform import tf_logging as logging
  15. import dragnn.python.load_dragnn_cc_impl
  16. import syntaxnet.load_parser_ops
  17. FLAGS = tf.app.flags.FLAGS
  18. if not hasattr(FLAGS, 'test_srcdir'):
  19. FLAGS.test_srcdir = ''
  20. if not hasattr(FLAGS, 'test_tmpdir'):
  21. FLAGS.test_tmpdir = tf.test.get_temp_dir()

_DUMMY_GOLD_SENTENCE = """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
}
"""

# The second sentence has a different length, to test the effect of
# mixed-length batches.
_DUMMY_GOLD_SENTENCE_2 = """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
}
"""

# The test sentences are the gold sentences with the tags and parse
# information removed.
_DUMMY_TEST_SENTENCE = """
token {
  word: "sentence" start: 0 end: 7
}
token {
  word: "0" start: 9 end: 9
}
token {
  word: "." start: 10 end: 10
}
"""

_DUMMY_TEST_SENTENCE_2 = """
token {
  word: "sentence" start: 0 end: 7
}
"""

_TAGGER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 tag: "NN"
}
token {
  word: "0" start: 9 end: 9 tag: "CD"
}
token {
  word: "." start: 10 end: 10 tag: "."
}
""", """
token {
  word: "sentence" start: 0 end: 7 tag: "NN"
}
"""
]

_TAGGER_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 tag: "CD" label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" label: "ROOT"
}
"""
]

_UNLABELED_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 label: "punct"
}
token {
  word: "0" start: 9 end: 9 head: 0 label: "punct"
}
token {
  word: "." start: 10 end: 10 head: 0 label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 label: "punct"
}
"""
]

_LABELED_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 label: "ROOT"
}
"""
]
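
# The gold/test/expected textprotos above parse into sentence_pb2.Sentence
# protos; a minimal sketch of how the tests below consume them:
#
#   doc = sentence_pb2.Sentence()
#   text_format.Parse(_DUMMY_GOLD_SENTENCE, doc)
#   assert doc.token[0].word == 'sentence'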


def _as_op(x):
  """Always returns the tf.Operation associated with a node."""
  return x.op if isinstance(x, tf.Tensor) else x


def _find_input_path(src, dst_predicate):
  """Finds an input path from `src` to a node that satisfies `dst_predicate`.

  TensorFlow graphs are directed. We generate paths from outputs to inputs,
  recursively searching both direct (i.e. data) and control inputs. Graphs
  with while_loop control flow may contain cycles, so we eliminate loops
  during the DFS.

  Args:
    src: tf.Tensor or tf.Operation root node.
    dst_predicate: function taking one argument (a node), returning true iff
      a target node has been found.

  Returns:
    A path from `src` to the first node that satisfies `dst_predicate`, or
    the empty list otherwise.
  """
  path_to = {src: None}

  def dfs(x):
    if dst_predicate(x):
      return x
    x_op = _as_op(x)
    for y in x_op.control_inputs + list(x_op.inputs):
      # Check whether we've already visited node `y`.
      if y not in path_to:
        path_to[y] = x
        res = dfs(y)
        if res is not None:
          return res
    return None

  dst = dfs(src)
  path = []
  while dst in path_to:
    path.append(dst)
    dst = path_to[dst]
  return list(reversed(path))


def _find_input_path_to_type(src, dst_type):
  """Finds a path from `src` to a node with type (i.e. kernel) `dst_type`."""
  return _find_input_path(src, lambda x: _as_op(x).type == dst_type)
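
# A minimal usage sketch for the path-finding helpers above (hypothetical
# graph, not part of the test suite): in `c = tf.add(a, b)`, the Const ops
# that produce `a` and `b` are data inputs of `c`, so a path to op type
# 'Const' exists, while a path to a nonexistent type does not:
#
#   with tf.Graph().as_default():
#     a = tf.constant(1.0, name='a')
#     b = tf.constant(2.0, name='b')
#     c = tf.add(a, b)
#     assert _find_input_path_to_type(c, 'Const')        # non-empty path
#     assert not _find_input_path_to_type(c, 'NoSuchOp')  # empty path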


class GraphBuilderTest(test_util.TensorFlowTestCase):

  def assertEmpty(self, container, msg=None):
    """Assert that an object has zero length.

    Args:
      container: Anything that implements the collections.Sized interface.
      msg: Optional message to report on failure.
    """
    if not isinstance(container, collections.Sized):
      self.fail('Expected a Sized object, got: '
                '{!r}'.format(type(container).__name__), msg)

    # Explicitly check the length, since some Sized objects (e.g.
    # numpy.ndarray) have strange __nonzero__/__bool__ behavior.
    if len(container):
      self.fail('{!r} has length of {}.'.format(container, len(container)),
                msg)

  def assertNotEmpty(self, container, msg=None):
    """Assert that an object has non-zero length.

    Args:
      container: Anything that implements the collections.Sized interface.
      msg: Optional message to report on failure.
    """
    if not isinstance(container, collections.Sized):
      self.fail('Expected a Sized object, got: '
                '{!r}'.format(type(container).__name__), msg)

    # Explicitly check the length, since some Sized objects (e.g.
    # numpy.ndarray) have strange __nonzero__/__bool__ behavior.
    if not len(container):
      self.fail('{!r} has length of 0.'.format(container), msg)

  def LoadSpec(self, spec_path):
    master_spec = spec_pb2.MasterSpec()
    testdata = os.path.join(FLAGS.test_srcdir,
                            'dragnn/core/testdata')
    with open(os.path.join(testdata, spec_path), 'r') as fin:
      text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
    return master_spec

  def MakeHyperparams(self, **kwargs):
    hyperparam_config = spec_pb2.GridPoint()
    for key in kwargs:
      setattr(hyperparam_config, key, kwargs[key])
    return hyperparam_config
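
  # Usage sketch for MakeHyperparams (not exercised at import time; any kwarg
  # must name a scalar field of spec_pb2.GridPoint, since protobuf message
  # fields cannot be assigned via setattr):
  #
  #   hp = self.MakeHyperparams(learning_method='adam',
  #                             use_moving_average=True)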

  def RunTraining(self, hyperparam_config):
    master_spec = self.LoadSpec('master_spec_link.textproto')

    self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]
    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

      target = spec_pb2.TrainTarget()
      target.name = 'testTraining-all'
      train = builder.add_training_from_config(target)

      with self.test_session() as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        # Run one iteration of training and verify nothing crashes.
        logging.info('Training')
        sess.run(train['run'],
                 feed_dict={train['input_batch']: reader_strings})

  def testTraining(self):
    """Tests the default hyperparameter settings."""
    self.RunTraining(self.MakeHyperparams())

  def testTrainingWithGradientClipping(self):
    """Adds code coverage for gradient clipping."""
    self.RunTraining(self.MakeHyperparams(gradient_clip_norm=1.25))

  def testTrainingWithAdamAndAveraging(self):
    """Adds code coverage for ADAM and the use of moving averaging."""
    self.RunTraining(
        self.MakeHyperparams(learning_method='adam', use_moving_average=True))

  def testTrainingWithCompositeOptimizer(self):
    """Adds code coverage for CompositeOptimizer."""
    grid_point = self.MakeHyperparams(learning_method='composite')
    grid_point.composite_optimizer_spec.method1.learning_method = 'adam'
    grid_point.composite_optimizer_spec.method2.learning_method = 'momentum'
    grid_point.composite_optimizer_spec.method2.momentum = 0.9
    self.RunTraining(grid_point)
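
  # For reference, the composite grid point built above corresponds roughly
  # to this GridPoint textproto (a sketch inferred from the field accesses
  # above, not a golden file):
  #
  #   learning_method: 'composite'
  #   composite_optimizer_spec {
  #     method1 { learning_method: 'adam' }
  #     method2 { learning_method: 'momentum' momentum: 0.9 }
  #   }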

  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()

      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verify that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        break_iter = 400
        is_resolved = False
        for i in range(0,
                       400):  # needs ~100 iterations, but is not deterministic
          # train['metrics'] is a flat array of per-component
          # (total, correct) action counts.
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
          total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
          correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50
          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          for c in range(len(master_spec.component)):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          master_spec.component[c].name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
            correct_val, total_val)
        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The results of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentences.append(sentence_pb2.Sentence())
          pred_sentences[-1].ParseFromString(annotation)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        # The test batch was [doc, doc, doc_2, doc], so the expected
        # annotations are indices [0, 0, 1, 0] into `expected`.
        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]
        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)

  def testSimpleTagger(self):
    self.RunFullTrainingAndInference('simple-tagger',
                                     'simple_tagger_master_spec.textproto')

  def testSimpleTaggerLayerNorm(self):
    spec = self.LoadSpec('simple_tagger_master_spec.textproto')
    spec.component[0].network_unit.parameters['layer_norm_hidden'] = 'True'
    spec.component[0].network_unit.parameters['layer_norm_input'] = 'True'
    self.RunFullTrainingAndInference('simple-tagger', master_spec=spec)

  def testSimpleTaggerLSTM(self):
    self.RunFullTrainingAndInference(
        'simple-tagger-lstm', 'simple_tagger_lstm_master_spec.textproto')

  def testSimpleTaggerWrappedLSTM(self):
    self.RunFullTrainingAndInference(
        'simple-tagger-wrapped-lstm',
        'simple_tagger_wrapped_lstm_master_spec.textproto')

  def testSplitTagger(self):
    self.RunFullTrainingAndInference('split-tagger',
                                     'split_tagger_master_spec.textproto')

  def testTaggerParser(self):
    self.RunFullTrainingAndInference(
        'tagger-parser',
        'tagger_parser_master_spec.textproto',
        component_weights=[0., 1., 1.],
        unroll_using_oracle=[False, True, True],
        expected_num_actions=12,
        expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testTaggerParserWithAttention(self):
    spec = self.LoadSpec('tagger_parser_master_spec.textproto')

    # Make the 'parser' component attend to the 'tagger' component.
    self.assertEqual('tagger', spec.component[1].name)
    self.assertEqual('parser', spec.component[2].name)
    spec.component[2].attention_component = 'tagger'

    # Attention + beam decoding is not yet supported.
    spec.component[2].inference_beam_size = 1

    # Running with batch size equal to 1 should be fine.
    self.RunFullTrainingAndInference(
        'tagger-parser',
        master_spec=spec,
        batch_size_limit=1,
        component_weights=[0., 1., 1.],
        unroll_using_oracle=[False, True, True],
        expected_num_actions=9,
        expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testTaggerParserWithAttentionBatchDeath(self):
    spec = self.LoadSpec('tagger_parser_master_spec.textproto')

    # Make the 'parser' component attend to the 'tagger' component.
    self.assertEqual('tagger', spec.component[1].name)
    self.assertEqual('parser', spec.component[2].name)
    spec.component[2].attention_component = 'tagger'

    # Trying to run with a batch size greater than 1 should fail:
    with self.assertRaises(tf.errors.InvalidArgumentError):
      self.RunFullTrainingAndInference(
          'tagger-parser',
          master_spec=spec,
          component_weights=[0., 1., 1.],
          unroll_using_oracle=[False, True, True],
          expected_num_actions=9,
          expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testSimpleParser(self):
    self.RunFullTrainingAndInference(
        'simple-parser',
        'simple_parser_master_spec.textproto',
        expected_num_actions=8,
        component_weights=[1],
        expected=_LABELED_PARSER_EXPECTED_SENTENCES)

  def checkOpOrder(self, name, endpoint, expected_op_order):
    """Checks that the ops leading up to `endpoint` run in the expected order.

    To check the order, we find a path along the directed graph formed by
    the inputs of each op. If op X has a chain of inputs to op Y, then X
    cannot be executed before Y. There may be multiple paths between any two
    ops, but the ops along any path are executed in that order. Therefore, we
    look up the expected ops in reverse order.

    Args:
      name: string name of the endpoint, for logging.
      endpoint: node whose execution we want to check.
      expected_op_order: string list of op types, in the order we expect them
        to be executed leading up to `endpoint`.
    """
    for target in reversed(expected_op_order):
      path = _find_input_path_to_type(endpoint, target)
      self.assertNotEmpty(path)
      logging.info('path[%d] from %s to %s: %s',
                   len(path), name, target, [_as_op(x).type for x in path])
      endpoint = path[-1]
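
  # Worked example of checkOpOrder's reverse lookup (a sketch): for
  #
  #   self.checkOpOrder('train', train['run'],
  #                     ['GetSession', 'ReleaseSession'])
  #
  # the first iteration finds a path train['run'] -> ... -> ReleaseSession,
  # and the second finds ReleaseSession -> ... -> GetSession. Chaining the
  # two paths shows that GetSession is an (indirect) input of ReleaseSession,
  # so it cannot execute after it.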

  def getBuilderAndTarget(
      self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
    """Generates a MasterBuilder and TrainTarget based on a simple spec."""
    master_spec = self.LoadSpec(master_spec_path)
    hyperparam_config = spec_pb2.GridPoint()
    target = spec_pb2.TrainTarget()
    target.name = 'test-%s-train' % test_name
    target.component_weights.extend([0] * len(master_spec.component))
    target.component_weights[-1] = 1.0
    target.unroll_using_oracle.extend([False] * len(master_spec.component))
    target.unroll_using_oracle[-1] = True
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    return builder, target

  def testGetSessionReleaseSession(self):
    """Checks that GetSession and ReleaseSession are called in order."""
    test_name = 'get-session-release-session'
    with tf.Graph().as_default():
      # Build the actual graphs. The choice of spec is arbitrary, as long as
      # training and annotation nodes can be constructed.
      builder, target = self.getBuilderAndTarget(test_name)
      train = builder.add_training_from_config(target)
      anno = builder.add_annotation(test_name)

      # We want to ensure that certain ops are executed in the correct order.
      # Specifically, the ops GetSession and ReleaseSession must both be
      # called, and in that order.
      #
      # First of all, the path to a nonexistent node type should be empty.
      path = _find_input_path_to_type(train['run'], 'foo')
      self.assertEmpty(path)

      # train['run'] is expected to start by calling GetSession, and to end
      # by calling ReleaseSession.
      self.checkOpOrder('train', train['run'],
                        ['GetSession', 'ReleaseSession'])

      # A similar contract applies to the annotations.
      self.checkOpOrder('annotations', anno['annotations'],
                        ['GetSession', 'ReleaseSession'])

  def testAttachDataReader(self):
    """Checks that train['run'] and 'annotations' call AttachDataReader."""
    test_name = 'attach-data-reader'
    with tf.Graph().as_default():
      builder, target = self.getBuilderAndTarget(test_name)
      train = builder.add_training_from_config(target)
      anno = builder.add_annotation(test_name)

      # AttachDataReader should be called between GetSession and
      # ReleaseSession.
      self.checkOpOrder('train', train['run'],
                        ['GetSession', 'AttachDataReader', 'ReleaseSession'])

      # A similar contract applies to the annotations.
      self.checkOpOrder('annotations', anno['annotations'],
                        ['GetSession', 'AttachDataReader', 'ReleaseSession'])

  def testSetTracingFalse(self):
    """Checks that 'annotations' doesn't call SetTracing if disabled."""
    test_name = 'set-tracing-false'
    with tf.Graph().as_default():
      builder, _ = self.getBuilderAndTarget(test_name)

      # Note: "enable_tracing=False" is the default.
      anno = builder.add_annotation(test_name, enable_tracing=False)

      # ReleaseSession should still be there.
      path = _find_input_path_to_type(anno['annotations'], 'ReleaseSession')
      self.assertNotEmpty(path)

      # As should AttachDataReader.
      path = _find_input_path_to_type(path[-1], 'AttachDataReader')
      self.assertNotEmpty(path)

      # But SetTracing should not be called.
      set_tracing_path = _find_input_path_to_type(path[-1], 'SetTracing')
      self.assertEmpty(set_tracing_path)

      # Instead, we should go to GetSession.
      path = _find_input_path_to_type(path[-1], 'GetSession')
      self.assertNotEmpty(path)

  def testSetTracingTrue(self):
    """Checks that 'annotations' does call SetTracing if enabled."""
    test_name = 'set-tracing-true'
    with tf.Graph().as_default():
      builder, _ = self.getBuilderAndTarget(test_name)
      anno = builder.add_annotation(test_name, enable_tracing=True)

      # Check that SetTracing is called after GetSession but before
      # AttachDataReader.
      self.checkOpOrder('annotations', anno['annotations'], [
          'GetSession', 'SetTracing', 'AttachDataReader', 'ReleaseSession'
      ])

      # The same contract applies to the 'traces' output.
      self.checkOpOrder('traces', anno['traces'], [
          'GetSession', 'SetTracing', 'AttachDataReader', 'ReleaseSession'
      ])


if __name__ == '__main__':
  googletest.main()