sprites_gen.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Generate the sprites tfrecords from raw_images."""
  16. import os
  17. import random
  18. import re
  19. import sys
  20. import numpy as np
  21. import scipy.misc
  22. import tensorflow as tf
  23. tf.flags.DEFINE_string('data_filepattern', '', 'The raw images.')
  24. tf.flags.DEFINE_string('out_file', '',
  25. 'File name for the tfrecord output.')
  26. def _read_images():
  27. """Read images from image files into data structure."""
  28. sprites = dict()
  29. files = tf.gfile.Glob(tf.flags.FLAGS.data_filepattern)
  30. for f in files:
  31. image = scipy.misc.imread(f)
  32. m = re.search('image_([0-9]+)_([0-9]+)_([0-9]+).jpg', os.path.basename(f))
  33. if m.group(1) not in sprites:
  34. sprites[m.group(1)] = dict()
  35. character = sprites[m.group(1)]
  36. if m.group(2) not in character:
  37. character[m.group(2)] = dict()
  38. pose = character[m.group(2)]
  39. pose[int(m.group(3))] = image
  40. return sprites
  41. def _images_to_example(image, image2):
  42. """Convert 2 consecutive image to a SequenceExample."""
  43. example = tf.SequenceExample()
  44. feature_list = example.feature_lists.feature_list['moving_objs']
  45. feature = feature_list.feature.add()
  46. feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
  47. feature = feature_list.feature.add()
  48. feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
  49. return example
  50. def generate_input():
  51. """Generate tfrecords."""
  52. sprites = _read_images()
  53. sys.stderr.write('Finish reading images.\n')
  54. train_writer = tf.python_io.TFRecordWriter(
  55. tf.flags.FLAGS.out_file.replace('sprites', 'sprites_train'))
  56. test_writer = tf.python_io.TFRecordWriter(
  57. tf.flags.FLAGS.out_file.replace('sprites', 'sprites_test'))
  58. train_examples = []
  59. test_examples = []
  60. for i in sprites:
  61. if int(i) < 24:
  62. examples = test_examples
  63. else:
  64. examples = train_examples
  65. character = sprites[i]
  66. for j in character.keys():
  67. pose = character[j]
  68. for k in xrange(1, len(pose), 1):
  69. image = pose[k]
  70. image2 = pose[k+1]
  71. examples.append(_images_to_example(image, image2))
  72. sys.stderr.write('Finish generating examples: %d, %d.\n' %
  73. (len(train_examples), len(test_examples)))
  74. random.shuffle(train_examples)
  75. _ = [train_writer.write(ex.SerializeToString()) for ex in train_examples]
  76. _ = [test_writer.write(ex.SerializeToString()) for ex in test_examples]
  77. def main(_):
  78. generate_input()
  79. if __name__ == '__main__':
  80. tf.app.run()