vgsl_model_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for vgsl_model."""
  16. import os
  17. import numpy as np
  18. import tensorflow as tf
  19. import vgsl_input
  20. import vgsl_model
  21. def _testdata(filename):
  22. return os.path.join('../testdata/', filename)
  23. def _rand(*size):
  24. return np.random.uniform(size=size).astype('f')
  25. class VgslModelTest(tf.test.TestCase):
  26. def testParseInputSpec(self):
  27. """The parser must return the numbers in the correct order.
  28. """
  29. shape = vgsl_model._ParseInputSpec(input_spec='32,42,256,3')
  30. self.assertEqual(
  31. shape,
  32. vgsl_input.ImageShape(
  33. batch_size=32, height=42, width=256, depth=3))
  34. # Nones must be inserted for zero sizes.
  35. shape = vgsl_model._ParseInputSpec(input_spec='1,0,0,3')
  36. self.assertEqual(
  37. shape,
  38. vgsl_input.ImageShape(
  39. batch_size=1, height=None, width=None, depth=3))
  40. def testParseOutputSpec(self):
  41. """The parser must return the correct args in the correct order.
  42. """
  43. out_dims, out_func, num_classes = vgsl_model._ParseOutputSpec(
  44. output_spec='O1c142')
  45. self.assertEqual(out_dims, 1)
  46. self.assertEqual(out_func, 'c')
  47. self.assertEqual(num_classes, 142)
  48. out_dims, out_func, num_classes = vgsl_model._ParseOutputSpec(
  49. output_spec='O2s99')
  50. self.assertEqual(out_dims, 2)
  51. self.assertEqual(out_func, 's')
  52. self.assertEqual(num_classes, 99)
  53. out_dims, out_func, num_classes = vgsl_model._ParseOutputSpec(
  54. output_spec='O0l12')
  55. self.assertEqual(out_dims, 0)
  56. self.assertEqual(out_func, 'l')
  57. self.assertEqual(num_classes, 12)
  58. def testPadLabels2d(self):
  59. """Must pad timesteps in labels to match logits.
  60. """
  61. with self.test_session() as sess:
  62. # Make placeholders for logits and labels.
  63. ph_logits = tf.placeholder(tf.float32, shape=(None, None, 42))
  64. ph_labels = tf.placeholder(tf.int64, shape=(None, None))
  65. padded_labels = vgsl_model._PadLabels2d(tf.shape(ph_logits)[1], ph_labels)
  66. # Make actual inputs.
  67. real_logits = _rand(4, 97, 42)
  68. real_labels = _rand(4, 85)
  69. np_array = sess.run([padded_labels],
  70. feed_dict={ph_logits: real_logits,
  71. ph_labels: real_labels})[0]
  72. self.assertEqual(tuple(np_array.shape), (4, 97))
  73. real_labels = _rand(4, 97)
  74. np_array = sess.run([padded_labels],
  75. feed_dict={ph_logits: real_logits,
  76. ph_labels: real_labels})[0]
  77. self.assertEqual(tuple(np_array.shape), (4, 97))
  78. real_labels = _rand(4, 100)
  79. np_array = sess.run([padded_labels],
  80. feed_dict={ph_logits: real_logits,
  81. ph_labels: real_labels})[0]
  82. self.assertEqual(tuple(np_array.shape), (4, 97))
  83. def testPadLabels3d(self):
  84. """Must pad height and width in labels to match logits.
  85. The tricky thing with 3-d is that the rows and columns need to remain
  86. intact, so we'll test it with small known data.
  87. """
  88. with self.test_session() as sess:
  89. # Make placeholders for logits and labels.
  90. ph_logits = tf.placeholder(tf.float32, shape=(None, None, None, 42))
  91. ph_labels = tf.placeholder(tf.int64, shape=(None, None, None))
  92. padded_labels = vgsl_model._PadLabels3d(ph_logits, ph_labels)
  93. # Make actual inputs.
  94. real_logits = _rand(1, 3, 4, 42)
  95. # Test all 9 combinations of height x width in [small, ok, big]
  96. real_labels = np.arange(6).reshape((1, 2, 3)) # Height small, width small
  97. np_array = sess.run([padded_labels],
  98. feed_dict={ph_logits: real_logits,
  99. ph_labels: real_labels})[0]
  100. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  101. self.assertAllEqual(np_array[0, :, :],
  102. [[0, 1, 2, 0], [3, 4, 5, 0], [0, 0, 0, 0]])
  103. real_labels = np.arange(8).reshape((1, 2, 4)) # Height small, width ok
  104. np_array = sess.run([padded_labels],
  105. feed_dict={ph_logits: real_logits,
  106. ph_labels: real_labels})[0]
  107. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  108. self.assertAllEqual(np_array[0, :, :],
  109. [[0, 1, 2, 3], [4, 5, 6, 7], [0, 0, 0, 0]])
  110. real_labels = np.arange(10).reshape((1, 2, 5)) # Height small, width big
  111. np_array = sess.run([padded_labels],
  112. feed_dict={ph_logits: real_logits,
  113. ph_labels: real_labels})[0]
  114. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  115. self.assertAllEqual(np_array[0, :, :],
  116. [[0, 1, 2, 3], [5, 6, 7, 8], [0, 0, 0, 0]])
  117. real_labels = np.arange(9).reshape((1, 3, 3)) # Height ok, width small
  118. np_array = sess.run([padded_labels],
  119. feed_dict={ph_logits: real_logits,
  120. ph_labels: real_labels})[0]
  121. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  122. self.assertAllEqual(np_array[0, :, :],
  123. [[0, 1, 2, 0], [3, 4, 5, 0], [6, 7, 8, 0]])
  124. real_labels = np.arange(12).reshape((1, 3, 4)) # Height ok, width ok
  125. np_array = sess.run([padded_labels],
  126. feed_dict={ph_logits: real_logits,
  127. ph_labels: real_labels})[0]
  128. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  129. self.assertAllEqual(np_array[0, :, :],
  130. [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
  131. real_labels = np.arange(15).reshape((1, 3, 5)) # Height ok, width big
  132. np_array = sess.run([padded_labels],
  133. feed_dict={ph_logits: real_logits,
  134. ph_labels: real_labels})[0]
  135. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  136. self.assertAllEqual(np_array[0, :, :],
  137. [[0, 1, 2, 3], [5, 6, 7, 8], [10, 11, 12, 13]])
  138. real_labels = np.arange(12).reshape((1, 4, 3)) # Height big, width small
  139. np_array = sess.run([padded_labels],
  140. feed_dict={ph_logits: real_logits,
  141. ph_labels: real_labels})[0]
  142. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  143. self.assertAllEqual(np_array[0, :, :],
  144. [[0, 1, 2, 0], [3, 4, 5, 0], [6, 7, 8, 0]])
  145. real_labels = np.arange(16).reshape((1, 4, 4)) # Height big, width ok
  146. np_array = sess.run([padded_labels],
  147. feed_dict={ph_logits: real_logits,
  148. ph_labels: real_labels})[0]
  149. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  150. self.assertAllEqual(np_array[0, :, :],
  151. [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
  152. real_labels = np.arange(20).reshape((1, 4, 5)) # Height big, width big
  153. np_array = sess.run([padded_labels],
  154. feed_dict={ph_logits: real_logits,
  155. ph_labels: real_labels})[0]
  156. self.assertEqual(tuple(np_array.shape), (1, 3, 4))
  157. self.assertAllEqual(np_array[0, :, :],
  158. [[0, 1, 2, 3], [5, 6, 7, 8], [10, 11, 12, 13]])
  159. def testEndToEndSizes0d(self):
  160. """Tests that the output sizes match when training/running real 0d data.
  161. Uses mnist with dual summarizing LSTMs to reduce to a single value.
  162. """
  163. filename = _testdata('mnist-tiny')
  164. with self.test_session() as sess:
  165. model = vgsl_model.InitNetwork(
  166. filename,
  167. model_spec='4,0,0,1[Cr5,5,16 Mp3,3 Lfys16 Lfxs16]O0s12',
  168. mode='train')
  169. tf.global_variables_initializer().run(session=sess)
  170. coord = tf.train.Coordinator()
  171. tf.train.start_queue_runners(sess=sess, coord=coord)
  172. _, step = model.TrainAStep(sess)
  173. self.assertEqual(step, 1)
  174. output, labels = model.RunAStep(sess)
  175. self.assertEqual(len(output.shape), 2)
  176. self.assertEqual(len(labels.shape), 1)
  177. self.assertEqual(output.shape[0], labels.shape[0])
  178. self.assertEqual(output.shape[1], 12)
  179. # TODO(rays) Support logistic and test with Imagenet (as 0d, multi-object.)
  180. def testEndToEndSizes1dCTC(self):
  181. """Tests that the output sizes match when training with CTC.
  182. Basic bidi LSTM on top of convolution and summarizing LSTM with CTC.
  183. """
  184. filename = _testdata('arial-32-tiny')
  185. with self.test_session() as sess:
  186. model = vgsl_model.InitNetwork(
  187. filename,
  188. model_spec='2,0,0,1[Cr5,5,16 Mp3,3 Lfys16 Lbx100]O1c105',
  189. mode='train')
  190. tf.global_variables_initializer().run(session=sess)
  191. coord = tf.train.Coordinator()
  192. tf.train.start_queue_runners(sess=sess, coord=coord)
  193. _, step = model.TrainAStep(sess)
  194. self.assertEqual(step, 1)
  195. output, labels = model.RunAStep(sess)
  196. self.assertEqual(len(output.shape), 3)
  197. self.assertEqual(len(labels.shape), 2)
  198. self.assertEqual(output.shape[0], labels.shape[0])
  199. # This is ctc - the only cast-iron guarantee is labels <= output.
  200. self.assertLessEqual(labels.shape[1], output.shape[1])
  201. self.assertEqual(output.shape[2], 105)
  202. def testEndToEndSizes1dFixed(self):
  203. """Tests that the output sizes match when training/running 1 data.
  204. Convolution, summarizing LSTM with fwd rev fwd to allow no CTC.
  205. """
  206. filename = _testdata('numbers-16-tiny')
  207. with self.test_session() as sess:
  208. model = vgsl_model.InitNetwork(
  209. filename,
  210. model_spec='8,0,0,1[Cr5,5,16 Mp3,3 Lfys16 Lfx64 Lrx64 Lfx64]O1s12',
  211. mode='train')
  212. tf.global_variables_initializer().run(session=sess)
  213. coord = tf.train.Coordinator()
  214. tf.train.start_queue_runners(sess=sess, coord=coord)
  215. _, step = model.TrainAStep(sess)
  216. self.assertEqual(step, 1)
  217. output, labels = model.RunAStep(sess)
  218. self.assertEqual(len(output.shape), 3)
  219. self.assertEqual(len(labels.shape), 2)
  220. self.assertEqual(output.shape[0], labels.shape[0])
  221. # Not CTC, output lengths match.
  222. self.assertEqual(output.shape[1], labels.shape[1])
  223. self.assertEqual(output.shape[2], 12)
  224. # TODO(rays) Get a 2-d dataset and support 2d (heat map) outputs.
  225. if __name__ == '__main__':
  226. tf.test.main()