prediction_input.py (4.1 KB)
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Code for building the input for the prediction model."""
  16. import os
  17. import numpy as np
  18. import tensorflow as tf
  19. from tensorflow.python.platform import flags
  20. from tensorflow.python.platform import gfile
  21. FLAGS = flags.FLAGS
  22. # Original image dimensions
  23. ORIGINAL_WIDTH = 640
  24. ORIGINAL_HEIGHT = 512
  25. COLOR_CHAN = 3
  26. # Default image dimensions.
  27. IMG_WIDTH = 64
  28. IMG_HEIGHT = 64
  29. # Dimension of the state and action.
  30. STATE_DIM = 5
  31. def build_tfrecord_input(training=True):
  32. """Create input tfrecord tensors.
  33. Args:
  34. training: training or validation data.
  35. Returns:
  36. list of tensors corresponding to images, actions, and states. The images
  37. tensor is 5D, batch x time x height x width x channels. The state and
  38. action tensors are 3D, batch x time x dimension.
  39. Raises:
  40. RuntimeError: if no files found.
  41. """
  42. filenames = gfile.Glob(os.path.join(FLAGS.data_dir, '*'))
  43. if not filenames:
  44. raise RuntimeError('No data files found.')
  45. index = int(np.floor(FLAGS.train_val_split * len(filenames)))
  46. if training:
  47. filenames = filenames[:index]
  48. else:
  49. filenames = filenames[index:]
  50. filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
  51. reader = tf.TFRecordReader()
  52. _, serialized_example = reader.read(filename_queue)
  53. image_seq, state_seq, action_seq = [], [], []
  54. for i in range(FLAGS.sequence_length):
  55. image_name = 'move/' + str(i) + '/image/encoded'
  56. action_name = 'move/' + str(i) + '/commanded_pose/vec_pitch_yaw'
  57. state_name = 'move/' + str(i) + '/endeffector/vec_pitch_yaw'
  58. if FLAGS.use_state:
  59. features = {image_name: tf.FixedLenFeature([1], tf.string),
  60. action_name: tf.FixedLenFeature([STATE_DIM], tf.float32),
  61. state_name: tf.FixedLenFeature([STATE_DIM], tf.float32)}
  62. else:
  63. features = {image_name: tf.FixedLenFeature([1], tf.string)}
  64. features = tf.parse_single_example(serialized_example, features=features)
  65. image_buffer = tf.reshape(features[image_name], shape=[])
  66. image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
  67. image.set_shape([ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])
  68. if IMG_HEIGHT != IMG_WIDTH:
  69. raise ValueError('Unequal height and width unsupported')
  70. crop_size = min(ORIGINAL_HEIGHT, ORIGINAL_WIDTH)
  71. image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
  72. image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
  73. image = tf.image.resize_bicubic(image, [IMG_HEIGHT, IMG_WIDTH])
  74. image = tf.cast(image, tf.float32) / 255.0
  75. image_seq.append(image)
  76. if FLAGS.use_state:
  77. state = tf.reshape(features[state_name], shape=[1, STATE_DIM])
  78. state_seq.append(state)
  79. action = tf.reshape(features[action_name], shape=[1, STATE_DIM])
  80. action_seq.append(action)
  81. image_seq = tf.concat(axis=0, values=image_seq)
  82. if FLAGS.use_state:
  83. state_seq = tf.concat(axis=0, values=state_seq)
  84. action_seq = tf.concat(axis=0, values=action_seq)
  85. [image_batch, action_batch, state_batch] = tf.train.batch(
  86. [image_seq, action_seq, state_seq],
  87. FLAGS.batch_size,
  88. num_threads=FLAGS.batch_size,
  89. capacity=100 * FLAGS.batch_size)
  90. return image_batch, action_batch, state_batch
  91. else:
  92. image_batch = tf.train.batch(
  93. [image_seq],
  94. FLAGS.batch_size,
  95. num_threads=FLAGS.batch_size,
  96. capacity=100 * FLAGS.batch_size)
  97. zeros_batch = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length, STATE_DIM])
  98. return image_batch, zeros_batch, zeros_batch