bulk_component_test.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for bulk_component.

Verifies that:

1. BulkFeatureExtractor and BulkAnnotator both raise NotImplementedError when a
   non-identity translator is configured.
2. BulkFeatureExtractor and BulkAnnotator both raise RuntimeError when
   recurrent linked features are configured.
3. BulkAnnotator raises RuntimeError when fixed features are configured.
4. BulkFeatureIdExtractor raises ValueError when linked features are
   configured, or when the fixed features are invalid.
"""

import os.path

import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

from google.protobuf import text_format

from dragnn.protos import spec_pb2
from dragnn.python import bulk_component
from dragnn.python import component
from dragnn.python import dragnn_ops
from dragnn.python import network_units
from syntaxnet import sentence_pb2

import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

FLAGS = tf.app.flags.FLAGS
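

# Lightweight stand-ins for the DRAGNN master, component, and network unit.
# The bulk component builders under test only need a master spec, hyperparams,
# a component lookup table, and a network that can report its layer sizes.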
class MockNetworkUnit(object):

  def get_layer_size(self, unused_layer_name):
    return 64


class MockComponent(object):

  def __init__(self):
    self.name = 'mock'
    self.network = MockNetworkUnit()


class MockMaster(object):

  def __init__(self):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {'mock': MockComponent()}


def _create_fake_corpus():
  """Returns a list of fake serialized sentences for tests."""
  num_docs = 4
  corpus = []
  for num_tokens in range(1, num_docs + 1):
    sentence = sentence_pb2.Sentence()
    sentence.text = 'x' * num_tokens
    for i in range(num_tokens):
      token = sentence.token.add()
      token.word = 'x'
      token.start = i
      token.end = i
    corpus.append(sentence.SerializeToString())
  return corpus


class BulkComponentTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.master = MockMaster()
    self.master_state = component.MasterState(
        handle='handle', current_batch_size=2)
    self.network_states = {
        'mock': component.NetworkState(),
        'test': component.NetworkState(),
    }

  def testFailsOnNonIdentityTranslator(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        linked_feature {
          name: "features" embedding_dim: -1 size: 1
          source_translator: "history"
          source_component: "mock"
        }
        """, component_spec)

    # For feature extraction:
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
          self.master, component_spec)

      # Expect feature extraction to raise an error due to the "history"
      # translator.
      with self.assertRaises(NotImplementedError):
        comp.build_greedy_training(self.master_state, self.network_states)

    # As well as annotation:
    with tf.Graph().as_default():
      comp = bulk_component.BulkAnnotatorComponentBuilder(
          self.master, component_spec)

      with self.assertRaises(NotImplementedError):
        comp.build_greedy_training(self.master_state, self.network_states)

  def testFailsOnRecurrentLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "FeedForwardNetwork"
          parameters {
            key: 'hidden_layer_sizes' value: '64'
          }
        }
        linked_feature {
          name: "features" embedding_dim: -1 size: 1
          source_translator: "identity"
          source_component: "test"
          source_layer: "layer_0"
        }
        """, component_spec)

    # For feature extraction:
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
          self.master, component_spec)

      # Expect feature extraction to raise an error due to the recurrent
      # linked feature (the component links to its own "layer_0" output).
      with self.assertRaises(RuntimeError):
        comp.build_greedy_training(self.master_state, self.network_states)

    # As well as annotation:
    with tf.Graph().as_default():
      comp = bulk_component.BulkAnnotatorComponentBuilder(
          self.master, component_spec)

      with self.assertRaises(RuntimeError):
        comp.build_greedy_training(self.master_state, self.network_states)

  def testConstantFixedFeatureFailsIfNotPretrained(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
          is_constant: true
        }
        component_builder {
          registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
          self.master, component_spec)
      with self.assertRaisesRegexp(ValueError,
                                   'Constant embeddings must be pretrained'):
        comp.build_greedy_training(self.master_state, self.network_states)
      with self.assertRaisesRegexp(ValueError,
                                   'Constant embeddings must be pretrained'):
        comp.build_greedy_inference(
            self.master_state, self.network_states, during_training=True)
      with self.assertRaisesRegexp(ValueError,
                                   'Constant embeddings must be pretrained'):
        comp.build_greedy_inference(
            self.master_state, self.network_states, during_training=False)

  def testNormalFixedFeaturesAreDifferentiable(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
          pretrained_embedding_matrix { part {} }
          vocab { part {} }
        }
        component_builder {
          registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
          self.master, component_spec)

      # Get embedding matrix variables.
      with tf.variable_scope(comp.name, reuse=True):
        fixed_embedding_matrix = tf.get_variable(
            network_units.fixed_embeddings_name(0))

      # Get output layer.
      comp.build_greedy_training(self.master_state, self.network_states)
      activations = self.network_states[comp.name].activations
      outputs = activations[comp.network.layers[0].name].bulk_tensor

      # Compute the gradient of the output layer w.r.t. the embedding matrix.
      # This should be well-defined in the normal case.
      gradients = tf.gradients(outputs, fixed_embedding_matrix)
      self.assertEqual(len(gradients), 1)
      self.assertFalse(gradients[0] is None)

  def testConstantFixedFeaturesAreNotDifferentiableButOthersAre(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "constant" embedding_dim: 32 size: 1
          is_constant: true
          pretrained_embedding_matrix { part {} }
          vocab { part {} }
        }
        fixed_feature {
          name: "trainable" embedding_dim: 32 size: 1
          pretrained_embedding_matrix { part {} }
          vocab { part {} }
        }
        component_builder {
          registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
          self.master, component_spec)

      # Get embedding matrix variables.
      with tf.variable_scope(comp.name, reuse=True):
        constant_embedding_matrix = tf.get_variable(
            network_units.fixed_embeddings_name(0))
        trainable_embedding_matrix = tf.get_variable(
            network_units.fixed_embeddings_name(1))

      # Get output layer.
      comp.build_greedy_training(self.master_state, self.network_states)
      activations = self.network_states[comp.name].activations
      outputs = activations[comp.network.layers[0].name].bulk_tensor

      # The constant embeddings are non-differentiable.
      constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
      self.assertEqual(len(constant_gradients), 1)
      self.assertTrue(constant_gradients[0] is None)

      # The trainable embeddings are differentiable.
      trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
      self.assertEqual(len(trainable_gradients), 1)
      self.assertFalse(trainable_gradients[0] is None)

  def testFailsOnFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "annotate"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkAnnotatorComponentBuilder(
          self.master, component_spec)

      # Expect the annotator to raise a runtime error due to the fixed
      # feature.
      with self.assertRaises(RuntimeError):
        comp.build_greedy_training(self.master_state, self.network_states)

  def testBulkFeatureIdExtractorOkWithOneFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: -1 size: 1
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
      # Should not raise errors.
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_training(self.master_state, self.network_states)
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_inference(self.master_state, self.network_states)

  def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: -1 size: 1
        }
        linked_feature {
          name: "linked" embedding_dim: -1 size: 1
          source_translator: "identity"
          source_component: "mock"
        }
        """, component_spec)
    with tf.Graph().as_default():
      with self.assertRaises(ValueError):
        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, component_spec)

  def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed1" embedding_dim: -1 size: 1
        }
        fixed_feature {
          name: "fixed2" embedding_dim: -1 size: 1
        }
        fixed_feature {
          name: "fixed3" embedding_dim: -1 size: 1
        }
        """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
      # Should not raise errors.
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_training(self.master_state, self.network_states)
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_inference(self.master_state, self.network_states)

  def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 2 size: 1
        }
        """, component_spec)
    with tf.Graph().as_default():
      with self.assertRaises(ValueError):
        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, component_spec)

  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    with open(path, 'w') as label_map_file:
      label_map_file.write('0\n')

    master_spec = spec_pb2.MasterSpec()
    text_format.Parse("""
        component {
          name: "test"
          transition_system {
            registered_name: "shift-only"
          }
          resource {
            name: "label-map"
            part {
              file_pattern: "%s"
              file_format: "text"
            }
          }
          network_unit {
            registered_name: "ExportFixedFeaturesNetwork"
          }
          backend {
            registered_name: "SyntaxNetComponent"
          }
          fixed_feature {
            name: "focus1" embedding_dim: -1 size: 1 fml: "input.focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus2" embedding_dim: -1 size: 1 fml: "input(1).focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus3" embedding_dim: -1 size: 1 fml: "input(2).focus"
            predicate_map: "none"
          }
        }
        """ % path, master_spec)

    with tf.Graph().as_default():
      corpus = _create_fake_corpus()
      corpus = tf.constant(corpus, shape=[len(corpus)])
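
      # Build a DRAGNN session around the fake corpus and initialize the
      # "test" component so its fixed feature IDs can be extracted in bulk.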
      handle = dragnn_ops.get_session(
          container='test',
          master_spec=master_spec.SerializeToString(),
          grid_point='')
      handle = dragnn_ops.attach_data_reader(handle, corpus)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=1, component='test')
      batch_size = dragnn_ops.batch_size(handle, component='test')
      master_state = component.MasterState(handle, batch_size)

      extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, master_spec.component[0])
      network_state = component.NetworkState()
      self.network_states['test'] = network_state
      handle = extractor.build_greedy_inference(master_state,
                                                self.network_states)

      focus1 = network_state.activations['focus1'].bulk_tensor
      focus2 = network_state.activations['focus2'].bulk_tensor
      focus3 = network_state.activations['focus3'].bulk_tensor

      with self.test_session() as sess:
        focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
        tf.logging.info('focus1=\n%s', focus1)
        tf.logging.info('focus2=\n%s', focus2)
        tf.logging.info('focus3=\n%s', focus3)
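
        # Each block of four rows corresponds to one fake sentence (1, 2, 3,
        # and 4 tokens, respectively), padded to the longest sentence in the
        # batch; -1 marks steps where the requested focus token does not exist.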
        self.assertAllEqual(
            focus1,
            [[0], [-1], [-1], [-1],
             [0], [1], [-1], [-1],
             [0], [1], [2], [-1],
             [0], [1], [2], [3]])
        self.assertAllEqual(
            focus2,
            [[-1], [-1], [-1], [-1],
             [1], [-1], [-1], [-1],
             [1], [2], [-1], [-1],
             [1], [2], [3], [-1]])
        self.assertAllEqual(
            focus3,
            [[-1], [-1], [-1], [-1],
             [-1], [-1], [-1], [-1],
             [2], [-1], [-1], [-1],
             [2], [3], [-1], [-1]])


if __name__ == '__main__':
  googletest.main()