example_gen.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Generate examples of two objects moving in different directions."""
  16. import random
  17. import sys
  18. import numpy as np
  19. import tensorflow as tf
# Path of the training tfrecords file; generate_input() also writes a held-out
# split to the same path with a '_test' suffix appended.
tf.flags.DEFINE_string('out_file', '',
                       'Output file for the tfrecords.')
  22. def _add_object(obj_type, image, image2, xpos, ypos):
  23. """Add a moving obj to two consecutive images."""
  24. obj_size = random.randint(8, 10)
  25. channel = random.randint(0, 2)
  26. move = random.randint(6, 10)
  27. obj = np.zeros([obj_size, obj_size, 3])
  28. if obj_type == 'rectangle':
  29. xpos2 = xpos + move
  30. ypos2 = ypos
  31. for i in xrange(obj_size):
  32. obj[i, 0:i+1, channel] = [1.0 for _ in xrange(i+1)]
  33. elif obj_type == 'square':
  34. xpos2 = xpos
  35. ypos2 = ypos + move
  36. obj[:, :, channel] = 1.0
  37. for x in xrange(obj_size):
  38. for y in xrange(obj_size):
  39. if obj[x, y, channel] == 1.0:
  40. image[xpos+x, ypos+y, channel] = 1.0
  41. image2[xpos2+x, ypos2+y, channel] = 1.0
  42. def _images_to_example(image, image2):
  43. """Convert two consecutive images to SequenceExample."""
  44. example = tf.SequenceExample()
  45. feature_list = example.feature_lists.feature_list['moving_objs']
  46. feature = feature_list.feature.add()
  47. feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
  48. feature = feature_list.feature.add()
  49. feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
  50. return example
  51. def generate_input():
  52. """Generate tfrecords."""
  53. writer = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file)
  54. writer2 = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file + '_test')
  55. examples = []
  56. for xpos in xrange(0, 40, 3):
  57. for ypos in xrange(0, 40, 3):
  58. for xpos2 in xrange(0, 40, 3):
  59. for ypos2 in xrange(0, 40, 3):
  60. image = np.zeros([64, 64, 3])
  61. image2 = np.zeros([64, 64, 3])
  62. _add_object('rectangle', image, image2, xpos, ypos)
  63. _add_object('square', image, image2, xpos2, ypos2)
  64. examples.append(_images_to_example(image, image2))
  65. sys.stderr.write('Finish generating examples.\n')
  66. random.shuffle(examples)
  67. for count, ex in enumerate(examples):
  68. if count % 10 == 0:
  69. writer2.write(ex.SerializeToString())
  70. else:
  71. writer.write(ex.SerializeToString())
  72. def main(_):
  73. generate_input()
  74. if __name__ == '__main__':
  75. tf.app.run()