skip_thoughts_model_test.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow_models.skip_thoughts.skip_thoughts_model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from skip_thoughts import configuration
from skip_thoughts import skip_thoughts_model


class SkipThoughtsModel(skip_thoughts_model.SkipThoughtsModel):
  """Subclass of SkipThoughtsModel without the disk I/O."""

  def build_inputs(self):
    if self.mode == "encode":
      # Encode mode doesn't read from disk, so defer to parent.
      return super(SkipThoughtsModel, self).build_inputs()
    else:
      # Replace disk I/O with random Tensors.
      self.encode_ids = tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)
      self.decode_pre_ids = tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)
      self.decode_post_ids = tf.random_uniform(
          [self.config.batch_size, 15],
          minval=0,
          maxval=self.config.vocab_size,
          dtype=tf.int64)
      self.encode_mask = tf.ones_like(self.encode_ids)
      self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
      self.decode_post_mask = tf.ones_like(self.decode_post_ids)
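

# Note: the build_inputs() override above swaps the TFRecord reading pipeline
# for random id Tensors of shape [batch_size, 15], so the tests below exercise
# graph construction only. The all-ones masks mark every timestep as a real
# (non-padded) token.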


class SkipThoughtsModelTest(tf.test.TestCase):

  def setUp(self):
    super(SkipThoughtsModelTest, self).setUp()
    self._model_config = configuration.model_config()

  def _countModelParameters(self):
    """Counts the number of parameters in the model at top level scope."""
    counter = {}
    for v in tf.global_variables():
      name = v.op.name.split("/")[0]
      num_params = v.get_shape().num_elements()
      if not num_params:
        self.fail("Could not infer num_elements from Variable %s" % v.op.name)
      counter[name] = counter.get(name, 0) + num_params
    return counter

  def _checkModelParameters(self):
    """Verifies the number of parameters in the model."""
    param_counts = self._countModelParameters()
    expected_param_counts = {
        # vocab_size * embedding_size
        "word_embedding": 12400000,
        # GRU Cells
        "encoder": 21772800,
        "decoder_pre": 21772800,
        "decoder_post": 21772800,
        # (encoder_dim + 1) * vocab_size
        "logits": 48020000,
        "global_step": 1,
    }
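    # With the default model_config (assumed here: vocab_size=20000,
    # word_embedding_dim=620, encoder_dim=2400), word_embedding works out to
    # 20000 * 620 = 12,400,000 and logits to (2400 + 1) * 20000 = 48,020,000.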
    self.assertDictEqual(expected_param_counts, param_counts)

  def _checkOutputs(self, expected_shapes, feed_dict=None):
    """Verifies that the model produces expected outputs.

    Args:
      expected_shapes: A dict mapping Tensor or Tensor name to expected output
        shape.
      feed_dict: Values of Tensors to feed into Session.run().
    """
    # Materialize the keys as a list so they can be indexed below (dict views
    # are not indexable under Python 3).
    fetches = list(expected_shapes.keys())

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      outputs = sess.run(fetches, feed_dict)

    for index, output in enumerate(outputs):
      tensor = fetches[index]
      expected = expected_shapes[tensor]
      actual = output.shape
      if expected != actual:
        self.fail("Tensor %s has shape %s (expected %s)." % (tensor, actual,
                                                             expected))

  def testBuildForTraining(self):
    model = SkipThoughtsModel(self._model_config, mode="train")
    model.build()

    self._checkModelParameters()

    expected_shapes = {
        # [batch_size, length]
        model.encode_ids: (128, 15),
        model.decode_pre_ids: (128, 15),
        model.decode_post_ids: (128, 15),
        model.encode_mask: (128, 15),
        model.decode_pre_mask: (128, 15),
        model.decode_post_mask: (128, 15),
        # [batch_size, length, word_embedding_dim]
        model.encode_emb: (128, 15, 620),
        model.decode_pre_emb: (128, 15, 620),
        model.decode_post_emb: (128, 15, 620),
        # [batch_size, encoder_dim]
        model.thought_vectors: (128, 2400),
        # [batch_size * length]
        model.target_cross_entropy_losses[0]: (1920,),
        model.target_cross_entropy_losses[1]: (1920,),
        # [batch_size * length]
        model.target_cross_entropy_loss_weights[0]: (1920,),
        model.target_cross_entropy_loss_weights[1]: (1920,),
        # Scalar
        model.total_loss: (),
    }
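    # The 1920-element shapes above are the flattened
    # batch_size * length = 128 * 15 losses and weights.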
    self._checkOutputs(expected_shapes)

  def testBuildForEval(self):
    model = SkipThoughtsModel(self._model_config, mode="eval")
    model.build()

    self._checkModelParameters()

    expected_shapes = {
        # [batch_size, length]
        model.encode_ids: (128, 15),
        model.decode_pre_ids: (128, 15),
        model.decode_post_ids: (128, 15),
        model.encode_mask: (128, 15),
        model.decode_pre_mask: (128, 15),
        model.decode_post_mask: (128, 15),
        # [batch_size, length, word_embedding_dim]
        model.encode_emb: (128, 15, 620),
        model.decode_pre_emb: (128, 15, 620),
        model.decode_post_emb: (128, 15, 620),
        # [batch_size, encoder_dim]
        model.thought_vectors: (128, 2400),
        # [batch_size * length]
        model.target_cross_entropy_losses[0]: (1920,),
        model.target_cross_entropy_losses[1]: (1920,),
        # [batch_size * length]
        model.target_cross_entropy_loss_weights[0]: (1920,),
        model.target_cross_entropy_loss_weights[1]: (1920,),
        # Scalar
        model.total_loss: (),
    }
    self._checkOutputs(expected_shapes)

  def testBuildForEncode(self):
    model = SkipThoughtsModel(self._model_config, mode="encode")
    model.build()

    # Test feeding a batch of word embeddings to get skip thought vectors.
    encode_emb = np.random.rand(64, 15, 620)
    encode_mask = np.ones((64, 15), dtype=np.int64)
    feed_dict = {model.encode_emb: encode_emb, model.encode_mask: encode_mask}
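    # In encode mode the model takes precomputed word embeddings as input
    # (fed here via feed_dict), so the batch size (64) need not match the
    # training batch size (128).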
    expected_shapes = {
        # [batch_size, encoder_dim]
        model.thought_vectors: (64, 2400),
    }
    self._checkOutputs(expected_shapes, feed_dict)
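

# A typical way to run this test (assumption: the directory containing the
# skip_thoughts package is on PYTHONPATH):
#
#   python -m skip_thoughts.skip_thoughts_model_test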

if __name__ == "__main__":
  tf.test.main()