bulk_component_test.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
"""Tests for bulk_component.

Verifies that:
1. BulkFeatureExtractor and BulkAnnotator both raise NotImplementedError when
   a non-identity translator is configured.
2. BulkFeatureExtractor and BulkAnnotator both raise RuntimeError when
   recurrent linked features are configured.
3. BulkAnnotator raises RuntimeError when fixed features are configured.
4. BulkFeatureIdExtractor raises ValueError when linked features are
   configured, or when the fixed features are invalid.
"""
  11. import os.path
  12. import tensorflow as tf
  13. from tensorflow.python.framework import test_util
  14. from tensorflow.python.platform import googletest
  15. from google.protobuf import text_format
  16. from dragnn.protos import spec_pb2
  17. from dragnn.python import bulk_component
  18. from dragnn.python import component
  19. from dragnn.python import dragnn_ops
  20. from dragnn.python import network_units
  21. from syntaxnet import sentence_pb2
  22. import dragnn.python.load_dragnn_cc_impl
  23. import syntaxnet.load_parser_ops
# Alias for TensorFlow's flag registry. NOTE(review): nothing in the visible
# part of this file reads FLAGS -- confirm it is still needed before removing.
FLAGS = tf.app.flags.FLAGS
  25. class MockNetworkUnit(object):
  26. def get_layer_size(self, unused_layer_name):
  27. return 64
  28. class MockComponent(object):
  29. def __init__(self):
  30. self.name = 'mock'
  31. self.network = MockNetworkUnit()
  32. class MockMaster(object):
  33. def __init__(self):
  34. self.spec = spec_pb2.MasterSpec()
  35. self.hyperparams = spec_pb2.GridPoint()
  36. self.lookup_component = {'mock': MockComponent()}
  37. def _create_fake_corpus():
  38. """Returns a list of fake serialized sentences for tests."""
  39. num_docs = 4
  40. corpus = []
  41. for num_tokens in range(1, num_docs + 1):
  42. sentence = sentence_pb2.Sentence()
  43. sentence.text = 'x' * num_tokens
  44. for i in range(num_tokens):
  45. token = sentence.token.add()
  46. token.word = 'x'
  47. token.start = i
  48. token.end = i
  49. corpus.append(sentence.SerializeToString())
  50. return corpus
  51. class BulkComponentTest(test_util.TensorFlowTestCase):
  52. def setUp(self):
  53. self.master = MockMaster()
  54. self.master_state = component.MasterState(
  55. handle='handle', current_batch_size=2)
  56. self.network_states = {
  57. 'mock': component.NetworkState(),
  58. 'test': component.NetworkState(),
  59. }
  60. def testFailsOnNonIdentityTranslator(self):
  61. component_spec = spec_pb2.ComponentSpec()
  62. text_format.Parse("""
  63. name: "test"
  64. network_unit {
  65. registered_name: "IdentityNetwork"
  66. }
  67. linked_feature {
  68. name: "features" embedding_dim: -1 size: 1
  69. source_translator: "history"
  70. source_component: "mock"
  71. }
  72. """, component_spec)
  73. # For feature extraction:
  74. with tf.Graph().as_default():
  75. comp = bulk_component.BulkFeatureExtractorComponentBuilder(
  76. self.master, component_spec)
  77. # Expect feature extraction to generate a error due to the "history"
  78. # translator.
  79. with self.assertRaises(NotImplementedError):
  80. comp.build_greedy_training(self.master_state, self.network_states)
  81. # As well as annotation:
  82. with tf.Graph().as_default():
  83. comp = bulk_component.BulkAnnotatorComponentBuilder(
  84. self.master, component_spec)
  85. with self.assertRaises(NotImplementedError):
  86. comp.build_greedy_training(self.master_state, self.network_states)
  87. def testFailsOnRecurrentLinkedFeature(self):
  88. component_spec = spec_pb2.ComponentSpec()
  89. text_format.Parse("""
  90. name: "test"
  91. network_unit {
  92. registered_name: "FeedForwardNetwork"
  93. parameters {
  94. key: 'hidden_layer_sizes' value: '64'
  95. }
  96. }
  97. linked_feature {
  98. name: "features" embedding_dim: -1 size: 1
  99. source_translator: "identity"
  100. source_component: "test"
  101. source_layer: "layer_0"
  102. }
  103. """, component_spec)
  104. # For feature extraction:
  105. with tf.Graph().as_default():
  106. comp = bulk_component.BulkFeatureExtractorComponentBuilder(
  107. self.master, component_spec)
  108. # Expect feature extraction to generate a error due to the "history"
  109. # translator.
  110. with self.assertRaises(RuntimeError):
  111. comp.build_greedy_training(self.master_state, self.network_states)
  112. # As well as annotation:
  113. with tf.Graph().as_default():
  114. comp = bulk_component.BulkAnnotatorComponentBuilder(
  115. self.master, component_spec)
  116. with self.assertRaises(RuntimeError):
  117. comp.build_greedy_training(self.master_state, self.network_states)
  118. def testConstantFixedFeatureFailsIfNotPretrained(self):
  119. component_spec = spec_pb2.ComponentSpec()
  120. text_format.Parse("""
  121. name: "test"
  122. network_unit {
  123. registered_name: "IdentityNetwork"
  124. }
  125. fixed_feature {
  126. name: "fixed" embedding_dim: 32 size: 1
  127. is_constant: true
  128. }
  129. component_builder {
  130. registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
  131. }
  132. """, component_spec)
  133. with tf.Graph().as_default():
  134. comp = bulk_component.BulkFeatureExtractorComponentBuilder(
  135. self.master, component_spec)
  136. with self.assertRaisesRegexp(ValueError,
  137. 'Constant embeddings must be pretrained'):
  138. comp.build_greedy_training(self.master_state, self.network_states)
  139. with self.assertRaisesRegexp(ValueError,
  140. 'Constant embeddings must be pretrained'):
  141. comp.build_greedy_inference(
  142. self.master_state, self.network_states, during_training=True)
  143. with self.assertRaisesRegexp(ValueError,
  144. 'Constant embeddings must be pretrained'):
  145. comp.build_greedy_inference(
  146. self.master_state, self.network_states, during_training=False)
  147. def testNormalFixedFeaturesAreDifferentiable(self):
  148. component_spec = spec_pb2.ComponentSpec()
  149. text_format.Parse("""
  150. name: "test"
  151. network_unit {
  152. registered_name: "IdentityNetwork"
  153. }
  154. fixed_feature {
  155. name: "fixed" embedding_dim: 32 size: 1
  156. pretrained_embedding_matrix { part {} }
  157. vocab { part {} }
  158. }
  159. component_builder {
  160. registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
  161. }
  162. """, component_spec)
  163. with tf.Graph().as_default():
  164. comp = bulk_component.BulkFeatureExtractorComponentBuilder(
  165. self.master, component_spec)
  166. # Get embedding matrix variables.
  167. with tf.variable_scope(comp.name, reuse=True):
  168. fixed_embedding_matrix = tf.get_variable(
  169. network_units.fixed_embeddings_name(0))
  170. # Get output layer.
  171. comp.build_greedy_training(self.master_state, self.network_states)
  172. activations = self.network_states[comp.name].activations
  173. outputs = activations[comp.network.layers[0].name].bulk_tensor
  174. # Compute the gradient of the output layer w.r.t. the embedding matrix.
  175. # This should be well-defined for in the normal case.
  176. gradients = tf.gradients(outputs, fixed_embedding_matrix)
  177. self.assertEqual(len(gradients), 1)
  178. self.assertFalse(gradients[0] is None)
  179. def testConstantFixedFeaturesAreNotDifferentiableButOthersAre(self):
  180. component_spec = spec_pb2.ComponentSpec()
  181. text_format.Parse("""
  182. name: "test"
  183. network_unit {
  184. registered_name: "IdentityNetwork"
  185. }
  186. fixed_feature {
  187. name: "constant" embedding_dim: 32 size: 1
  188. is_constant: true
  189. pretrained_embedding_matrix { part {} }
  190. vocab { part {} }
  191. }
  192. fixed_feature {
  193. name: "trainable" embedding_dim: 32 size: 1
  194. pretrained_embedding_matrix { part {} }
  195. vocab { part {} }
  196. }
  197. component_builder {
  198. registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
  199. }
  200. """, component_spec)
  201. with tf.Graph().as_default():
  202. comp = bulk_component.BulkFeatureExtractorComponentBuilder(
  203. self.master, component_spec)
  204. # Get embedding matrix variables.
  205. with tf.variable_scope(comp.name, reuse=True):
  206. constant_embedding_matrix = tf.get_variable(
  207. network_units.fixed_embeddings_name(0))
  208. trainable_embedding_matrix = tf.get_variable(
  209. network_units.fixed_embeddings_name(1))
  210. # Get output layer.
  211. comp.build_greedy_training(self.master_state, self.network_states)
  212. activations = self.network_states[comp.name].activations
  213. outputs = activations[comp.network.layers[0].name].bulk_tensor
  214. # The constant embeddings are non-differentiable.
  215. constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
  216. self.assertEqual(len(constant_gradients), 1)
  217. self.assertTrue(constant_gradients[0] is None)
  218. # The trainable embeddings are differentiable.
  219. trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
  220. self.assertEqual(len(trainable_gradients), 1)
  221. self.assertFalse(trainable_gradients[0] is None)
  222. def testFailsOnFixedFeature(self):
  223. component_spec = spec_pb2.ComponentSpec()
  224. text_format.Parse("""
  225. name: "annotate"
  226. network_unit {
  227. registered_name: "IdentityNetwork"
  228. }
  229. fixed_feature {
  230. name: "fixed" embedding_dim: 32 size: 1
  231. }
  232. """, component_spec)
  233. with tf.Graph().as_default():
  234. comp = bulk_component.BulkAnnotatorComponentBuilder(
  235. self.master, component_spec)
  236. # Expect feature extraction to generate a runtime error due to the
  237. # fixed feature.
  238. with self.assertRaises(RuntimeError):
  239. comp.build_greedy_training(self.master_state, self.network_states)
  240. def testBulkFeatureIdExtractorOkWithOneFixedFeature(self):
  241. component_spec = spec_pb2.ComponentSpec()
  242. text_format.Parse("""
  243. name: "test"
  244. network_unit {
  245. registered_name: "IdentityNetwork"
  246. }
  247. fixed_feature {
  248. name: "fixed" embedding_dim: -1 size: 1
  249. }
  250. """, component_spec)
  251. with tf.Graph().as_default():
  252. comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
  253. self.master, component_spec)
  254. # Should not raise errors.
  255. self.network_states[component_spec.name] = component.NetworkState()
  256. comp.build_greedy_training(self.master_state, self.network_states)
  257. self.network_states[component_spec.name] = component.NetworkState()
  258. comp.build_greedy_inference(self.master_state, self.network_states)
  259. def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
  260. component_spec = spec_pb2.ComponentSpec()
  261. text_format.Parse("""
  262. name: "test"
  263. network_unit {
  264. registered_name: "IdentityNetwork"
  265. }
  266. fixed_feature {
  267. name: "fixed" embedding_dim: -1 size: 1
  268. }
  269. linked_feature {
  270. name: "linked" embedding_dim: -1 size: 1
  271. source_translator: "identity"
  272. source_component: "mock"
  273. }
  274. """, component_spec)
  275. with tf.Graph().as_default():
  276. with self.assertRaises(ValueError):
  277. unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
  278. self.master, component_spec)
  279. def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
  280. component_spec = spec_pb2.ComponentSpec()
  281. text_format.Parse("""
  282. name: "test"
  283. network_unit {
  284. registered_name: "IdentityNetwork"
  285. }
  286. fixed_feature {
  287. name: "fixed1" embedding_dim: -1 size: 1
  288. }
  289. fixed_feature {
  290. name: "fixed2" embedding_dim: -1 size: 1
  291. }
  292. fixed_feature {
  293. name: "fixed3" embedding_dim: -1 size: 1
  294. }
  295. """, component_spec)
  296. with tf.Graph().as_default():
  297. comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
  298. self.master, component_spec)
  299. # Should not raise errors.
  300. self.network_states[component_spec.name] = component.NetworkState()
  301. comp.build_greedy_training(self.master_state, self.network_states)
  302. self.network_states[component_spec.name] = component.NetworkState()
  303. comp.build_greedy_inference(self.master_state, self.network_states)
  304. def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
  305. component_spec = spec_pb2.ComponentSpec()
  306. text_format.Parse("""
  307. name: "test"
  308. network_unit {
  309. registered_name: "IdentityNetwork"
  310. }
  311. fixed_feature {
  312. name: "fixed" embedding_dim: 2 size: 1
  313. }
  314. """, component_spec)
  315. with tf.Graph().as_default():
  316. with self.assertRaises(ValueError):
  317. unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
  318. self.master, component_spec)
  319. def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
  320. path = os.path.join(tf.test.get_temp_dir(), 'label-map')
  321. with open(path, 'w') as label_map_file:
  322. label_map_file.write('0\n')
  323. master_spec = spec_pb2.MasterSpec()
  324. text_format.Parse("""
  325. component {
  326. name: "test"
  327. transition_system {
  328. registered_name: "shift-only"
  329. }
  330. resource {
  331. name: "label-map"
  332. part {
  333. file_pattern: "%s"
  334. file_format: "text"
  335. }
  336. }
  337. network_unit {
  338. registered_name: "ExportFixedFeaturesNetwork"
  339. }
  340. backend {
  341. registered_name: "SyntaxNetComponent"
  342. }
  343. fixed_feature {
  344. name: "focus1" embedding_dim: -1 size: 1 fml: "input.focus"
  345. predicate_map: "none"
  346. }
  347. fixed_feature {
  348. name: "focus2" embedding_dim: -1 size: 1 fml: "input(1).focus"
  349. predicate_map: "none"
  350. }
  351. fixed_feature {
  352. name: "focus3" embedding_dim: -1 size: 1 fml: "input(2).focus"
  353. predicate_map: "none"
  354. }
  355. }
  356. """ % path, master_spec)
  357. with tf.Graph().as_default():
  358. corpus = _create_fake_corpus()
  359. corpus = tf.constant(corpus, shape=[len(corpus)])
  360. handle = dragnn_ops.get_session(
  361. container='test',
  362. master_spec=master_spec.SerializeToString(),
  363. grid_point='')
  364. handle = dragnn_ops.attach_data_reader(handle, corpus)
  365. handle = dragnn_ops.init_component_data(
  366. handle, beam_size=1, component='test')
  367. batch_size = dragnn_ops.batch_size(handle, component='test')
  368. master_state = component.MasterState(handle, batch_size)
  369. extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
  370. self.master, master_spec.component[0])
  371. network_state = component.NetworkState()
  372. self.network_states['test'] = network_state
  373. handle = extractor.build_greedy_inference(master_state,
  374. self.network_states)
  375. focus1 = network_state.activations['focus1'].bulk_tensor
  376. focus2 = network_state.activations['focus2'].bulk_tensor
  377. focus3 = network_state.activations['focus3'].bulk_tensor
  378. with self.test_session() as sess:
  379. focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
  380. tf.logging.info('focus1=\n%s', focus1)
  381. tf.logging.info('focus2=\n%s', focus2)
  382. tf.logging.info('focus3=\n%s', focus3)
  383. self.assertAllEqual(
  384. focus1,
  385. [[0], [-1], [-1], [-1],
  386. [0], [1], [-1], [-1],
  387. [0], [1], [2], [-1],
  388. [0], [1], [2], [3]])
  389. self.assertAllEqual(
  390. focus2,
  391. [[-1], [-1], [-1], [-1],
  392. [1], [-1], [-1], [-1],
  393. [1], [2], [-1], [-1],
  394. [1], [2], [3], [-1]])
  395. self.assertAllEqual(
  396. focus3,
  397. [[-1], [-1], [-1], [-1],
  398. [-1], [-1], [-1], [-1],
  399. [2], [-1], [-1], [-1],
  400. [2], [3], [-1], [-1]])
# Run the tests through googletest when this file is executed as a script.
if __name__ == '__main__':
  googletest.main()