1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- # Copyright 2016 The TensorFlow Authors All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Generate the sprites tfrecords from raw_images."""
- import os
- import random
- import re
- import sys
- import numpy as np
- import scipy.misc
- import tensorflow as tf
- tf.flags.DEFINE_string('data_filepattern', '', 'The raw images.')
- tf.flags.DEFINE_string('out_file', '',
- 'File name for the tfrecord output.')
- def _read_images():
- """Read images from image files into data structure."""
- sprites = dict()
- files = tf.gfile.Glob(tf.flags.FLAGS.data_filepattern)
- for f in files:
- image = scipy.misc.imread(f)
- m = re.search('image_([0-9]+)_([0-9]+)_([0-9]+).jpg', os.path.basename(f))
- if m.group(1) not in sprites:
- sprites[m.group(1)] = dict()
- character = sprites[m.group(1)]
- if m.group(2) not in character:
- character[m.group(2)] = dict()
- pose = character[m.group(2)]
- pose[int(m.group(3))] = image
- return sprites
- def _images_to_example(image, image2):
- """Convert 2 consecutive image to a SequenceExample."""
- example = tf.SequenceExample()
- feature_list = example.feature_lists.feature_list['moving_objs']
- feature = feature_list.feature.add()
- feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
- feature = feature_list.feature.add()
- feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
- return example
- def generate_input():
- """Generate tfrecords."""
- sprites = _read_images()
- sys.stderr.write('Finish reading images.\n')
- train_writer = tf.python_io.TFRecordWriter(
- tf.flags.FLAGS.out_file.replace('sprites', 'sprites_train'))
- test_writer = tf.python_io.TFRecordWriter(
- tf.flags.FLAGS.out_file.replace('sprites', 'sprites_test'))
- train_examples = []
- test_examples = []
- for i in sprites:
- if int(i) < 24:
- examples = test_examples
- else:
- examples = train_examples
- character = sprites[i]
- for j in character.keys():
- pose = character[j]
- for k in xrange(1, len(pose), 1):
- image = pose[k]
- image2 = pose[k+1]
- examples.append(_images_to_example(image, image2))
- sys.stderr.write('Finish generating examples: %d, %d.\n' %
- (len(train_examples), len(test_examples)))
- random.shuffle(train_examples)
- _ = [train_writer.write(ex.SerializeToString()) for ex in train_examples]
- _ = [test_writer.write(ex.SerializeToString()) for ex in test_examples]
- def main(_):
- generate_input()
- if __name__ == '__main__':
- tf.app.run()
|