# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for graph_builder."""
  16. import collections
  17. import os.path
  18. import numpy as np
  19. import tensorflow as tf
  20. from google.protobuf import text_format
  21. from dragnn.protos import spec_pb2
  22. from dragnn.protos import trace_pb2
  23. from dragnn.python import dragnn_ops
  24. from dragnn.python import graph_builder
  25. from syntaxnet import sentence_pb2
  26. from tensorflow.python.framework import test_util
  27. from tensorflow.python.platform import googletest
  28. from tensorflow.python.platform import tf_logging as logging
  29. import dragnn.python.load_dragnn_cc_impl
  30. import syntaxnet.load_parser_ops
  31. FLAGS = tf.app.flags.FLAGS
  32. if not hasattr(FLAGS, 'test_srcdir'):
  33. FLAGS.test_srcdir = ''
  34. if not hasattr(FLAGS, 'test_tmpdir'):
  35. FLAGS.test_tmpdir = tf.test.get_temp_dir()

_DUMMY_GOLD_SENTENCE = """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
}
"""

# The second sentence has a different length, to test the effect of
# mixed-length batches.
_DUMMY_GOLD_SENTENCE_2 = """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
}
"""

# The test sentence is the gold sentence with the tags and parse information
# removed.
_DUMMY_TEST_SENTENCE = """
token {
  word: "sentence" start: 0 end: 7
}
token {
  word: "0" start: 9 end: 9
}
token {
  word: "." start: 10 end: 10
}
"""

_DUMMY_TEST_SENTENCE_2 = """
token {
  word: "sentence" start: 0 end: 7
}
"""

_TAGGER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 tag: "NN"
}
token {
  word: "0" start: 9 end: 9 tag: "CD"
}
token {
  word: "." start: 10 end: 10 tag: "."
}
""", """
token {
  word: "sentence" start: 0 end: 7 tag: "NN"
}
"""
]

_TAGGER_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 tag: "CD" label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 tag: "NN" label: "ROOT"
}
"""
]

_UNLABELED_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 label: "punct"
}
token {
  word: "0" start: 9 end: 9 head: 0 label: "punct"
}
token {
  word: "." start: 10 end: 10 head: 0 label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 label: "punct"
}
"""
]

_LABELED_PARSER_EXPECTED_SENTENCES = [
    """
token {
  word: "sentence" start: 0 end: 7 label: "ROOT"
}
token {
  word: "0" start: 9 end: 9 head: 0 label: "num"
}
token {
  word: "." start: 10 end: 10 head: 0 label: "punct"
}
""", """
token {
  word: "sentence" start: 0 end: 7 label: "ROOT"
}
"""
]


def _as_op(x):
  """Always returns the tf.Operation associated with a node."""
  return x.op if isinstance(x, tf.Tensor) else x


def _find_input_path(src, dst_predicate):
  """Finds an input path from `src` to a node that satisfies `dst_predicate`.

  TensorFlow graphs are directed. We generate paths from outputs to inputs,
  recursively searching both direct (i.e. data) and control inputs. Graphs
  with while_loop control flow may contain cycles, so we eliminate loops
  during the DFS.

  Args:
    src: tf.Tensor or tf.Operation root node.
    dst_predicate: function taking one argument (a node), returning true iff
      a target node has been found.

  Returns:
    a path from `src` to the first node that satisfies `dst_predicate`, or the
    empty list otherwise.
  """
  path_to = {src: None}

  def dfs(x):
    if dst_predicate(x):
      return x
    x_op = _as_op(x)
    for y in x_op.control_inputs + list(x_op.inputs):
      # Check if we've already visited node `y`.
      if y not in path_to:
        path_to[y] = x
        res = dfs(y)
        if res is not None:
          return res
    return None

  dst = dfs(src)
  path = []
  while dst in path_to:
    path.append(dst)
    dst = path_to[dst]
  return list(reversed(path))


def _find_input_path_to_type(src, dst_type):
  """Finds a path from `src` to a node with type (i.e. kernel) `dst_type`."""
  return _find_input_path(src, lambda x: _as_op(x).type == dst_type)
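

# A minimal usage sketch for the path helpers above (illustrative only, not
# part of the original test suite). It assumes the standard TF op type names
# 'Const' and 'Add'; the function is never invoked by the tests.
def _example_input_path():
  """Builds a tiny graph and walks from its output back to a Const input."""
  with tf.Graph().as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.add(a, b, name='c')
    # The path runs from `c` back through the Add op's inputs to the first
    # Const node found, e.g. yielding the op names ['c', 'a'].
    path = _find_input_path_to_type(c, 'Const')
    return [_as_op(x).name for x in path]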


class GraphBuilderTest(test_util.TensorFlowTestCase):

  def assertEmpty(self, container, msg=None):
    """Assert that an object has zero length.

    Args:
      container: Anything that implements the collections.Sized interface.
      msg: Optional message to report on failure.
    """
    if not isinstance(container, collections.Sized):
      self.fail('Expected a Sized object, got: '
                '{!r}'.format(type(container).__name__), msg)

    # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
    # have strange __nonzero__/__bool__ behavior.
    if len(container):
      self.fail('{!r} has length of {}.'.format(container, len(container)), msg)

  def assertNotEmpty(self, container, msg=None):
    """Assert that an object has non-zero length.

    Args:
      container: Anything that implements the collections.Sized interface.
      msg: Optional message to report on failure.
    """
    if not isinstance(container, collections.Sized):
      self.fail('Expected a Sized object, got: '
                '{!r}'.format(type(container).__name__), msg)

    # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
    # have strange __nonzero__/__bool__ behavior.
    if not len(container):
      self.fail('{!r} has length of 0.'.format(container), msg)

  def LoadSpec(self, spec_path):
    """Loads a MasterSpec text proto, resolving the TESTDATA placeholder."""
    master_spec = spec_pb2.MasterSpec()
    testdata = os.path.join(FLAGS.test_srcdir,
                            'dragnn/core/testdata')
    with open(os.path.join(testdata, spec_path), 'r') as fin:
      text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
    return master_spec

  def MakeHyperparams(self, **kwargs):
    """Constructs a GridPoint proto from the given keyword arguments."""
    hyperparam_config = spec_pb2.GridPoint()
    for key in kwargs:
      setattr(hyperparam_config, key, kwargs[key])
    return hyperparam_config

  def RunTraining(self, hyperparam_config):
    master_spec = self.LoadSpec('master_spec_link.textproto')

    self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]
    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

      target = spec_pb2.TrainTarget()
      target.name = 'testTraining-all'
      train = builder.add_training_from_config(target)
      with self.test_session() as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        # Run one iteration of training and verify nothing crashes.
        logging.info('Training')
        sess.run(train['run'],
                 feed_dict={train['input_batch']: reader_strings})

  def testTraining(self):
    """Tests the default hyperparameter settings."""
    self.RunTraining(self.MakeHyperparams())

  def testTrainingWithGradientClipping(self):
    """Adds code coverage for gradient clipping."""
    self.RunTraining(self.MakeHyperparams(gradient_clip_norm=1.25))

  def testTrainingWithAdamAndAveraging(self):
    """Adds code coverage for ADAM and the use of moving averages."""
    self.RunTraining(
        self.MakeHyperparams(learning_method='adam', use_moving_average=True))

  def testTrainingWithCompositeOptimizer(self):
    """Adds code coverage for CompositeOptimizer."""
    grid_point = self.MakeHyperparams(learning_method='composite')
    grid_point.composite_optimizer_spec.method1.learning_method = 'adam'
    grid_point.composite_optimizer_spec.method2.learning_method = 'momentum'
    grid_point.composite_optimizer_spec.method2.momentum = 0.9
    self.RunTraining(grid_point)

  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
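      # Unless explicitly overridden via the arguments above, train only the
      # final component: it gets all of the loss weight and is unrolled along
      # the oracle path.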
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()
      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verifies that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        break_iter = 400
        is_resolved = False
        # Training needs ~100 iterations, but is not deterministic.
        for i in range(0, 400):
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
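          # train['metrics'] is a flat array of per-component action counts,
          # laid out as [total_0, correct_0, total_1, correct_1, ...]; summing
          # each column of the (-1, 2) reshape aggregates across components.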
          total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
          correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50
          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          for c in range(len(master_spec.component)):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          master_spec.component[c].name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
            correct_val, total_val)
        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The results of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentences.append(sentence_pb2.Sentence())
          pred_sentences[-1].ParseFromString(annotation)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        # The test batch is [doc, doc, doc_2, doc], so expectations follow.
        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)

  def testSimpleTagger(self):
    self.RunFullTrainingAndInference('simple-tagger',
                                     'simple_tagger_master_spec.textproto')

  def testSimpleTaggerLayerNorm(self):
    spec = self.LoadSpec('simple_tagger_master_spec.textproto')
    spec.component[0].network_unit.parameters['layer_norm_hidden'] = 'True'
    spec.component[0].network_unit.parameters['layer_norm_input'] = 'True'
    self.RunFullTrainingAndInference('simple-tagger', master_spec=spec)

  def testSimpleTaggerLSTM(self):
    self.RunFullTrainingAndInference('simple-tagger-lstm',
                                     'simple_tagger_lstm_master_spec.textproto')

  def testSimpleTaggerWrappedLSTM(self):
    self.RunFullTrainingAndInference(
        'simple-tagger-wrapped-lstm',
        'simple_tagger_wrapped_lstm_master_spec.textproto')

  def testSplitTagger(self):
    self.RunFullTrainingAndInference('split-tagger',
                                     'split_tagger_master_spec.textproto')

  def testTaggerParser(self):
    self.RunFullTrainingAndInference(
        'tagger-parser',
        'tagger_parser_master_spec.textproto',
        component_weights=[0., 1., 1.],
        unroll_using_oracle=[False, True, True],
        expected_num_actions=12,
        expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testTaggerParserWithAttention(self):
    spec = self.LoadSpec('tagger_parser_master_spec.textproto')

    # Make the 'parser' component attend to the 'tagger' component.
    self.assertEqual('tagger', spec.component[1].name)
    self.assertEqual('parser', spec.component[2].name)
    spec.component[2].attention_component = 'tagger'

    # Attention + beam decoding is not yet supported.
    spec.component[2].inference_beam_size = 1

    # Running with batch size equal to 1 should be fine.
    self.RunFullTrainingAndInference(
        'tagger-parser',
        master_spec=spec,
        batch_size_limit=1,
        component_weights=[0., 1., 1.],
        unroll_using_oracle=[False, True, True],
        expected_num_actions=9,
        expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testTaggerParserWithAttentionBatchDeath(self):
    spec = self.LoadSpec('tagger_parser_master_spec.textproto')

    # Make the 'parser' component attend to the 'tagger' component.
    self.assertEqual('tagger', spec.component[1].name)
    self.assertEqual('parser', spec.component[2].name)
    spec.component[2].attention_component = 'tagger'

    # Trying to run with a batch size greater than 1 should fail:
    with self.assertRaises(tf.errors.InvalidArgumentError):
      self.RunFullTrainingAndInference(
          'tagger-parser',
          master_spec=spec,
          component_weights=[0., 1., 1.],
          unroll_using_oracle=[False, True, True],
          expected_num_actions=9,
          expected=_TAGGER_PARSER_EXPECTED_SENTENCES)

  def testStructuredTrainingNotImplementedDeath(self):
    spec = self.LoadSpec('simple_parser_master_spec.textproto')

    # Make the 'parser' component have a beam at training time.
    self.assertEqual('parser', spec.component[0].name)
    spec.component[0].training_beam_size = 8

    # The training run should fail at runtime rather than at build time.
    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
                                 r'\[Not implemented.\]'):
      self.RunFullTrainingAndInference(
          'simple-parser',
          master_spec=spec,
          expected_num_actions=8,
          component_weights=[1],
          expected=_LABELED_PARSER_EXPECTED_SENTENCES)

  def testSimpleParser(self):
    self.RunFullTrainingAndInference(
        'simple-parser',
        'simple_parser_master_spec.textproto',
        expected_num_actions=8,
        component_weights=[1],
        expected=_LABELED_PARSER_EXPECTED_SENTENCES)

  def checkOpOrder(self, name, endpoint, expected_op_order):
    """Checks that the ops leading up to `endpoint` run in the expected order.

    To check the order, we find a path along the directed graph formed by
    the inputs of each op. If op X has a chain of inputs to op Y, then X
    cannot be executed before Y. There may be multiple paths between any two
    ops, but the ops along any path are executed in that order. Therefore, we
    look up the expected ops in reverse order.

    Args:
      name: string name of the endpoint, for logging.
      endpoint: node whose execution we want to check.
      expected_op_order: string list of op types, in the order we expect them
        to be executed leading up to `endpoint`.
    """
    for target in reversed(expected_op_order):
      path = _find_input_path_to_type(endpoint, target)
      self.assertNotEmpty(path)
      logging.info('path[%d] from %s to %s: %s',
                   len(path), name, target, [_as_op(x).type for x in path])
      endpoint = path[-1]
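
  def exampleCheckOpOrder(self):
    """Illustrative sketch, not part of the original suite.

    Demonstrates what checkOpOrder verifies on a tiny graph, assuming the
    standard TF op type names 'Const' and 'Add'. The name deliberately avoids
    the 'test' prefix so the runner never executes it.
    """
    with tf.Graph().as_default():
      a = tf.constant(1.0)  # op type 'Const'
      total = tf.add(a, tf.constant(2.0))  # op type 'Add'
      # Every input path from `total` hits 'Add' (at the endpoint itself)
      # before reaching a 'Const', so this expected order holds.
      self.checkOpOrder('total', total, ['Const', 'Add'])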

  def getBuilderAndTarget(
      self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
    """Generates a MasterBuilder and TrainTarget based on a simple spec."""
    master_spec = self.LoadSpec(master_spec_path)
    hyperparam_config = spec_pb2.GridPoint()
    target = spec_pb2.TrainTarget()
    target.name = 'test-%s-train' % test_name
    target.component_weights.extend([0] * len(master_spec.component))
    target.component_weights[-1] = 1.0
    target.unroll_using_oracle.extend([False] * len(master_spec.component))
    target.unroll_using_oracle[-1] = True
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    return builder, target

  def testGetSessionReleaseSession(self):
    """Checks that GetSession and ReleaseSession are called in order."""
    test_name = 'get-session-release-session'
    with tf.Graph().as_default():
      # Build the actual graphs. The choice of spec is arbitrary, as long as
      # training and annotation nodes can be constructed.
      builder, target = self.getBuilderAndTarget(test_name)
      train = builder.add_training_from_config(target)
      anno = builder.add_annotation(test_name)

      # We want to ensure that certain ops are executed in the correct order.
      # Specifically, the ops GetSession and ReleaseSession must both be
      # called, and in that order.
      #
      # First of all, the path to a non-existent node type should be empty.
      path = _find_input_path_to_type(train['run'], 'foo')
      self.assertEmpty(path)

      # train['run'] is expected to start by calling GetSession and to end by
      # calling ReleaseSession.
      self.checkOpOrder('train', train['run'], ['GetSession', 'ReleaseSession'])

      # A similar contract applies to the annotations.
      self.checkOpOrder('annotations', anno['annotations'],
                        ['GetSession', 'ReleaseSession'])

  def testAttachDataReader(self):
    """Checks that train['run'] and 'annotations' call AttachDataReader."""
    test_name = 'attach-data-reader'
    with tf.Graph().as_default():
      builder, target = self.getBuilderAndTarget(test_name)
      train = builder.add_training_from_config(target)
      anno = builder.add_annotation(test_name)

      # AttachDataReader should be called between GetSession and
      # ReleaseSession.
      self.checkOpOrder('train', train['run'],
                        ['GetSession', 'AttachDataReader', 'ReleaseSession'])

      # A similar contract applies to the annotations.
      self.checkOpOrder('annotations', anno['annotations'],
                        ['GetSession', 'AttachDataReader', 'ReleaseSession'])

  def testSetTracingFalse(self):
    """Checks that 'annotations' doesn't call SetTracing if disabled."""
    test_name = 'set-tracing-false'
    with tf.Graph().as_default():
      builder, _ = self.getBuilderAndTarget(test_name)

      # Note: "enable_tracing=False" is the default.
      anno = builder.add_annotation(test_name, enable_tracing=False)

      # ReleaseSession should still be there.
      path = _find_input_path_to_type(anno['annotations'], 'ReleaseSession')
      self.assertNotEmpty(path)

      # As should AttachDataReader.
      path = _find_input_path_to_type(path[-1], 'AttachDataReader')
      self.assertNotEmpty(path)

      # But SetTracing should not be called.
      set_tracing_path = _find_input_path_to_type(path[-1], 'SetTracing')
      self.assertEmpty(set_tracing_path)

      # Instead, we should go to GetSession.
      path = _find_input_path_to_type(path[-1], 'GetSession')
      self.assertNotEmpty(path)

  def testSetTracingTrue(self):
    """Checks that 'annotations' does call SetTracing if enabled."""
    test_name = 'set-tracing-true'
    with tf.Graph().as_default():
      builder, _ = self.getBuilderAndTarget(test_name)
      anno = builder.add_annotation(test_name, enable_tracing=True)

      # Check that SetTracing is called after GetSession but before
      # AttachDataReader.
      self.checkOpOrder('annotations', anno['annotations'], [
          'GetSession', 'SetTracing', 'AttachDataReader', 'ReleaseSession'
      ])

      # The same contract applies to the 'traces' output.
      self.checkOpOrder('traces', anno['traces'], [
          'GetSession', 'SetTracing', 'AttachDataReader', 'ReleaseSession'
      ])


if __name__ == '__main__':
  googletest.main()