# vgslspecs_test.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for vgslspecs."""
  16. import numpy as np
  17. import tensorflow as tf
  18. import vgslspecs
  19. def _rand(*size):
  20. return np.random.uniform(size=size).astype('f')
  21. class VgslspecsTest(tf.test.TestCase):
  22. def __init__(self, other):
  23. super(VgslspecsTest, self).__init__(other)
  24. self.max_width = 36
  25. self.max_height = 24
  26. self.batch_size = 4
  27. def SetupInputs(self):
  28. # Make placeholders for standard inputs.
  29. # Everything is variable in the input, except the depth.
  30. self.ph_image = tf.placeholder(
  31. tf.float32, shape=(None, None, None, 3), name='inputs')
  32. self.ph_widths = tf.placeholder(tf.int64, shape=(None,), name='w')
  33. self.ph_heights = tf.placeholder(tf.int64, shape=(None,), name='h')
  34. # Make actual inputs.
  35. self.in_image = _rand(self.batch_size, self.max_height, self.max_width, 3)
  36. self.in_widths = [24, 12, self.max_width, 30]
  37. self.in_heights = [self.max_height, 18, 12, 6]
  38. def ExpectScaledSize(self, spec, target_shape, factor=1):
  39. """Tests that the output of the graph of the given spec has target_shape."""
  40. with tf.Graph().as_default():
  41. with self.test_session() as sess:
  42. self.SetupInputs()
  43. # Only the placeholders are given at construction time.
  44. vgsl = vgslspecs.VGSLSpecs(self.ph_widths, self.ph_heights, True)
  45. outputs = vgsl.Build(self.ph_image, spec)
  46. # Compute the expected output widths from the given scale factor.
  47. target_widths = tf.div(self.in_widths, factor).eval()
  48. target_heights = tf.div(self.in_heights, factor).eval()
  49. # Run with the 'real' data.
  50. tf.global_variables_initializer().run()
  51. res_image, res_widths, res_heights = sess.run(
  52. [outputs, vgsl.GetLengths(2), vgsl.GetLengths(1)],
  53. feed_dict={self.ph_image: self.in_image,
  54. self.ph_widths: self.in_widths,
  55. self.ph_heights: self.in_heights})
  56. self.assertEqual(tuple(res_image.shape), target_shape)
  57. if target_shape[1] > 1:
  58. self.assertEqual(tuple(res_heights), tuple(target_heights))
  59. if target_shape[2] > 1:
  60. self.assertEqual(tuple(res_widths), tuple(target_widths))
  61. def testSameSizeConv(self):
  62. """Test all types of Conv. There is no scaling."""
  63. self.ExpectScaledSize(
  64. '[Cs{MyConv}5,5,16 Ct3,3,12 Cr4,4,24 Cl5,5,64]',
  65. (self.batch_size, self.max_height, self.max_width, 64))
  66. def testSameSizeLSTM(self):
  67. """Test all non-reducing LSTMs. Output depth is doubled with BiDi."""
  68. self.ExpectScaledSize('[Lfx16 Lrx8 Do Lbx24 Lfy12 Do{MyDo} Lry7 Lby32]',
  69. (self.batch_size, self.max_height, self.max_width,
  70. 64))
  71. def testSameSizeParallel(self):
  72. """Parallel affects depth, but not scale."""
  73. self.ExpectScaledSize('[Cs5,5,16 (Lfx{MyLSTM}32 Lrx32 Lbx16)]',
  74. (self.batch_size, self.max_height, self.max_width,
  75. 96))
  76. def testScalingOps(self):
  77. """Test a heterogeneous series with scaling."""
  78. self.ExpectScaledSize('[Cs5,5,16 Mp{MyPool}2,2 Ct3,3,32 Mp3,3 Lfx32 Lry64]',
  79. (self.batch_size, self.max_height / 6,
  80. self.max_width / 6, 64), 6)
  81. def testXReduction(self):
  82. """Test a heterogeneous series with reduction of x-dimension."""
  83. self.ExpectScaledSize('[Cr5,5,16 Mp2,2 Ct3,3,32 Mp3,3 Lfxs32 Lry64]',
  84. (self.batch_size, self.max_height / 6, 1, 64), 6)
  85. def testYReduction(self):
  86. """Test a heterogeneous series with reduction of y-dimension."""
  87. self.ExpectScaledSize('[Cl5,5,16 Mp2,2 Ct3,3,32 Mp3,3 Lfys32 Lfx64]',
  88. (self.batch_size, 1, self.max_width / 6, 64), 6)
  89. def testXYReduction(self):
  90. """Test a heterogeneous series with reduction to 0-d."""
  91. self.ExpectScaledSize(
  92. '[Cr5,5,16 Lfys32 Lfxs64 Fr{MyFC}16 Ft20 Fl12 Fs32 Fm40]',
  93. (self.batch_size, 1, 1, 40))
  94. def testReshapeTile(self):
  95. """Tests that a tiled input can be reshaped to the batch dimension."""
  96. self.ExpectScaledSize('[S2(3x0)0,2 Cr5,5,16 Lfys16]',
  97. (self.batch_size * 3, 1, self.max_width / 3, 16), 3)
  98. def testReshapeDepth(self):
  99. """Tests that depth can be reshaped to the x dimension."""
  100. self.ExpectScaledSize('[Cl5,5,16 Mp3,3 (Lrys32 Lbys16 Lfys32) S3(3x0)2,3]',
  101. (self.batch_size, 1, self.max_width, 32))
  102. if __name__ == '__main__':
  103. tf.test.main()