# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Code for building the input for the prediction model."""
- import os
- import numpy as np
- import tensorflow as tf
- from tensorflow.python.platform import flags
- from tensorflow.python.platform import gfile
- FLAGS = flags.FLAGS
- # Original image dimensions
- ORIGINAL_WIDTH = 640
- ORIGINAL_HEIGHT = 512
- COLOR_CHAN = 3
- # Default image dimensions.
- IMG_WIDTH = 64
- IMG_HEIGHT = 64
- # Dimension of the state and action.
- STATE_DIM = 5
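
# The flags read below (FLAGS.data_dir, FLAGS.train_val_split,
# FLAGS.sequence_length, FLAGS.use_state and FLAGS.batch_size) are not
# defined in this module; the training script that imports it is expected
# to define them. A hypothetical sketch of such definitions (names taken
# from the usages below, default values illustrative only):
#
#   flags.DEFINE_string('data_dir', 'push/push_train', 'directory of TFRecords')
#   flags.DEFINE_float('train_val_split', 0.95, 'fraction of files for training')
#   flags.DEFINE_integer('sequence_length', 10, 'frames per example')
#   flags.DEFINE_bool('use_state', True, 'parse actions/states as well as images')
#   flags.DEFINE_integer('batch_size', 32, 'examples per batch')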


def build_tfrecord_input(training=True):
  """Create input tfrecord tensors.

  Args:
    training: whether to build the training (True) or validation (False) split.

  Returns:
    Tuple of (images, actions, states) tensors. The images tensor is 5-D,
    batch x time x height x width x channels. The state and action tensors
    are 3-D, batch x time x dimension.

  Raises:
    RuntimeError: if no data files are found.
  """
  filenames = gfile.Glob(os.path.join(FLAGS.data_dir, '*'))
  if not filenames:
    raise RuntimeError('No data files found.')

  # Split the file list into training and validation portions.
  index = int(np.floor(FLAGS.train_val_split * len(filenames)))
  if training:
    filenames = filenames[:index]
  else:
    filenames = filenames[index:]

  filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)

  image_seq, state_seq, action_seq = [], [], []

  for i in range(FLAGS.sequence_length):
    # Per-step feature keys: one encoded image, a commanded pose (action)
    # and an end-effector pose (state) per time step.
    image_name = 'move/' + str(i) + '/image/encoded'
    action_name = 'move/' + str(i) + '/commanded_pose/vec_pitch_yaw'
    state_name = 'move/' + str(i) + '/endeffector/vec_pitch_yaw'
    if FLAGS.use_state:
      features = {image_name: tf.FixedLenFeature([1], tf.string),
                  action_name: tf.FixedLenFeature([STATE_DIM], tf.float32),
                  state_name: tf.FixedLenFeature([STATE_DIM], tf.float32)}
    else:
      features = {image_name: tf.FixedLenFeature([1], tf.string)}
    features = tf.parse_single_example(serialized_example, features=features)

    image_buffer = tf.reshape(features[image_name], shape=[])
    image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
    image.set_shape([ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])

    if IMG_HEIGHT != IMG_WIDTH:
      raise ValueError('Unequal height and width unsupported')

    # Center-crop to a square, downsample to IMG_HEIGHT x IMG_WIDTH, and
    # rescale pixel values to [0, 1].
    crop_size = min(ORIGINAL_HEIGHT, ORIGINAL_WIDTH)
    image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
    image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
    image = tf.image.resize_bicubic(image, [IMG_HEIGHT, IMG_WIDTH])
    image = tf.cast(image, tf.float32) / 255.0
    image_seq.append(image)

    if FLAGS.use_state:
      state = tf.reshape(features[state_name], shape=[1, STATE_DIM])
      state_seq.append(state)
      action = tf.reshape(features[action_name], shape=[1, STATE_DIM])
      action_seq.append(action)

  image_seq = tf.concat(axis=0, values=image_seq)

  if FLAGS.use_state:
    state_seq = tf.concat(axis=0, values=state_seq)
    action_seq = tf.concat(axis=0, values=action_seq)
    [image_batch, action_batch, state_batch] = tf.train.batch(
        [image_seq, action_seq, state_seq],
        FLAGS.batch_size,
        num_threads=FLAGS.batch_size,
        capacity=100 * FLAGS.batch_size)
    return image_batch, action_batch, state_batch
  else:
    # tf.train.batch returns a list when given a list, so unpack the single
    # image tensor rather than returning a one-element list.
    [image_batch] = tf.train.batch(
        [image_seq],
        FLAGS.batch_size,
        num_threads=FLAGS.batch_size,
        capacity=100 * FLAGS.batch_size)
    # Without state, return zeros of the expected shape in place of the
    # action and state batches.
    zeros_batch = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length, STATE_DIM])
    return image_batch, zeros_batch, zeros_batch
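

if __name__ == '__main__':
  # Minimal usage sketch (not part of the original module): pull one batch
  # from the pipeline. Assumes the flags above have been defined and that
  # FLAGS.data_dir points at a directory of TFRecord files.
  images, actions, states = build_tfrecord_input(training=True)
  with tf.Session() as sess:
    # The TF1 queue-based readers need queue runners started explicitly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image_np = sess.run(images)
    print(image_np.shape)  # (batch_size, sequence_length, 64, 64, 3)
    coord.request_stop()
    coord.join(threads)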