prediction_model.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Model architecture for predictive model, including CDNA, DNA, and STP."""
  16. import numpy as np
  17. import tensorflow as tf
  18. import tensorflow.contrib.slim as slim
  19. from tensorflow.contrib.layers.python import layers as tf_layers
  20. from lstm_ops import basic_conv_lstm_cell
  21. # Amount to use when lower bounding tensors
  22. RELU_SHIFT = 1e-12
  23. # kernel size for DNA and CDNA.
  24. DNA_KERN_SIZE = 5
  25. def construct_model(images,
  26. actions=None,
  27. states=None,
  28. iter_num=-1.0,
  29. k=-1,
  30. use_state=True,
  31. num_masks=10,
  32. stp=False,
  33. cdna=True,
  34. dna=False,
  35. context_frames=2):
  36. """Build convolutional lstm video predictor using STP, CDNA, or DNA.
  37. Args:
  38. images: tensor of ground truth image sequences
  39. actions: tensor of action sequences
  40. states: tensor of ground truth state sequences
  41. iter_num: tensor of the current training iteration (for sched. sampling)
  42. k: constant used for scheduled sampling. -1 to feed in own prediction.
  43. use_state: True to include state and action in prediction
  44. num_masks: the number of different pixel motion predictions (and
  45. the number of masks for each of those predictions)
  46. stp: True to use Spatial Transformer Predictor (STP)
  47. cdna: True to use Convoluational Dynamic Neural Advection (CDNA)
  48. dna: True to use Dynamic Neural Advection (DNA)
  49. context_frames: number of ground truth frames to pass in before
  50. feeding in own predictions
  51. Returns:
  52. gen_images: predicted future image frames
  53. gen_states: predicted future states
  54. Raises:
  55. ValueError: if more than one network option specified or more than 1 mask
  56. specified for DNA model.
  57. """
  58. if stp + cdna + dna != 1:
  59. raise ValueError('More than one, or no network option specified.')
  60. batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4]
  61. lstm_func = basic_conv_lstm_cell
  62. # Generated robot states and images.
  63. gen_states, gen_images = [], []
  64. current_state = states[0]
  65. if k == -1:
  66. feedself = True
  67. else:
  68. # Scheduled sampling:
  69. # Calculate number of ground-truth frames to pass in.
  70. num_ground_truth = tf.to_int32(
  71. tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
  72. feedself = False
  73. # LSTM state sizes and states.
  74. lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
  75. lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
  76. lstm_state5, lstm_state6, lstm_state7 = None, None, None
  77. for image, action in zip(images[:-1], actions[:-1]):
  78. # Reuse variables after the first timestep.
  79. reuse = bool(gen_images)
  80. done_warm_start = len(gen_images) > context_frames - 1
  81. with slim.arg_scope(
  82. [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
  83. tf_layers.layer_norm, slim.layers.conv2d_transpose],
  84. reuse=reuse):
  85. if feedself and done_warm_start:
  86. # Feed in generated image.
  87. prev_image = gen_images[-1]
  88. elif done_warm_start:
  89. # Scheduled sampling
  90. prev_image = scheduled_sample(image, gen_images[-1], batch_size,
  91. num_ground_truth)
  92. else:
  93. # Always feed in ground_truth
  94. prev_image = image
  95. # Predicted state is always fed back in
  96. state_action = tf.concat(axis=1, values=[action, current_state])
  97. enc0 = slim.layers.conv2d(
  98. prev_image,
  99. 32, [5, 5],
  100. stride=2,
  101. scope='scale1_conv1',
  102. normalizer_fn=tf_layers.layer_norm,
  103. normalizer_params={'scope': 'layer_norm1'})
  104. hidden1, lstm_state1 = lstm_func(
  105. enc0, lstm_state1, lstm_size[0], scope='state1')
  106. hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
  107. hidden2, lstm_state2 = lstm_func(
  108. hidden1, lstm_state2, lstm_size[1], scope='state2')
  109. hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
  110. enc1 = slim.layers.conv2d(
  111. hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')
  112. hidden3, lstm_state3 = lstm_func(
  113. enc1, lstm_state3, lstm_size[2], scope='state3')
  114. hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
  115. hidden4, lstm_state4 = lstm_func(
  116. hidden3, lstm_state4, lstm_size[3], scope='state4')
  117. hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
  118. enc2 = slim.layers.conv2d(
  119. hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')
  120. # Pass in state and action.
  121. smear = tf.reshape(
  122. state_action,
  123. [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
  124. smear = tf.tile(
  125. smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
  126. if use_state:
  127. enc2 = tf.concat(axis=3, values=[enc2, smear])
  128. enc3 = slim.layers.conv2d(
  129. enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')
  130. hidden5, lstm_state5 = lstm_func(
  131. enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8
  132. hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
  133. enc4 = slim.layers.conv2d_transpose(
  134. hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
  135. hidden6, lstm_state6 = lstm_func(
  136. enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16
  137. hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
  138. # Skip connection.
  139. hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16
  140. enc5 = slim.layers.conv2d_transpose(
  141. hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
  142. hidden7, lstm_state7 = lstm_func(
  143. enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32
  144. hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
  145. # Skip connection.
  146. hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32
  147. enc6 = slim.layers.conv2d_transpose(
  148. hidden7,
  149. hidden7.get_shape()[3], 3, stride=2, scope='convt3',
  150. normalizer_fn=tf_layers.layer_norm,
  151. normalizer_params={'scope': 'layer_norm9'})
  152. if dna:
  153. # Using largest hidden state for predicting untied conv kernels.
  154. enc7 = slim.layers.conv2d_transpose(
  155. enc6, DNA_KERN_SIZE**2, 1, stride=1, scope='convt4')
  156. else:
  157. # Using largest hidden state for predicting a new image layer.
  158. enc7 = slim.layers.conv2d_transpose(
  159. enc6, color_channels, 1, stride=1, scope='convt4')
  160. # This allows the network to also generate one image from scratch,
  161. # which is useful when regions of the image become unoccluded.
  162. transformed = [tf.nn.sigmoid(enc7)]
  163. if stp:
  164. stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
  165. stp_input1 = slim.layers.fully_connected(
  166. stp_input0, 100, scope='fc_stp')
  167. transformed += stp_transformation(prev_image, stp_input1, num_masks)
  168. elif cdna:
  169. cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
  170. transformed += cdna_transformation(prev_image, cdna_input, num_masks,
  171. int(color_channels))
  172. elif dna:
  173. # Only one mask is supported (more should be unnecessary).
  174. if num_masks != 1:
  175. raise ValueError('Only one mask is supported for DNA model.')
  176. transformed = [dna_transformation(prev_image, enc7)]
  177. masks = slim.layers.conv2d_transpose(
  178. enc6, num_masks + 1, 1, stride=1, scope='convt7')
  179. masks = tf.reshape(
  180. tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
  181. [int(batch_size), int(img_height), int(img_width), num_masks + 1])
  182. mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
  183. output = mask_list[0] * prev_image
  184. for layer, mask in zip(transformed, mask_list[1:]):
  185. output += layer * mask
  186. gen_images.append(output)
  187. current_state = slim.layers.fully_connected(
  188. state_action,
  189. int(current_state.get_shape()[1]),
  190. scope='state_pred',
  191. activation_fn=None)
  192. gen_states.append(current_state)
  193. return gen_images, gen_states
  194. ## Utility functions
  195. def stp_transformation(prev_image, stp_input, num_masks):
  196. """Apply spatial transformer predictor (STP) to previous image.
  197. Args:
  198. prev_image: previous image to be transformed.
  199. stp_input: hidden layer to be used for computing STN parameters.
  200. num_masks: number of masks and hence the number of STP transformations.
  201. Returns:
  202. List of images transformed by the predicted STP parameters.
  203. """
  204. # Only import spatial transformer if needed.
  205. from spatial_transformer import transformer
  206. identity_params = tf.convert_to_tensor(
  207. np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
  208. transformed = []
  209. for i in range(num_masks - 1):
  210. params = slim.layers.fully_connected(
  211. stp_input, 6, scope='stp_params' + str(i),
  212. activation_fn=None) + identity_params
  213. transformed.append(transformer(prev_image, params))
  214. return transformed
  215. def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
  216. """Apply convolutional dynamic neural advection to previous image.
  217. Args:
  218. prev_image: previous image to be transformed.
  219. cdna_input: hidden lyaer to be used for computing CDNA kernels.
  220. num_masks: the number of masks and hence the number of CDNA transformations.
  221. color_channels: the number of color channels in the images.
  222. Returns:
  223. List of images transformed by the predicted CDNA kernels.
  224. """
  225. batch_size = int(cdna_input.get_shape()[0])
  226. # Predict kernels using linear function of last hidden layer.
  227. cdna_kerns = slim.layers.fully_connected(
  228. cdna_input,
  229. DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
  230. scope='cdna_params',
  231. activation_fn=None)
  232. # Reshape and normalize.
  233. cdna_kerns = tf.reshape(
  234. cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
  235. cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
  236. norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
  237. cdna_kerns /= norm_factor
  238. cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
  239. cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
  240. prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)
  241. # Transform image.
  242. transformed = []
  243. for kernel, preimg in zip(cdna_kerns, prev_images):
  244. kernel = tf.squeeze(kernel)
  245. if len(kernel.get_shape()) == 3:
  246. kernel = tf.expand_dims(kernel, -1)
  247. transformed.append(
  248. tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
  249. transformed = tf.concat(axis=0, values=transformed)
  250. transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
  251. return transformed
  252. def dna_transformation(prev_image, dna_input):
  253. """Apply dynamic neural advection to previous image.
  254. Args:
  255. prev_image: previous image to be transformed.
  256. dna_input: hidden lyaer to be used for computing DNA transformation.
  257. Returns:
  258. List of images transformed by the predicted CDNA kernels.
  259. """
  260. # Construct translated images.
  261. prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
  262. image_height = int(prev_image.get_shape()[1])
  263. image_width = int(prev_image.get_shape()[2])
  264. inputs = []
  265. for xkern in range(DNA_KERN_SIZE):
  266. for ykern in range(DNA_KERN_SIZE):
  267. inputs.append(
  268. tf.expand_dims(
  269. tf.slice(prev_image_pad, [0, xkern, ykern, 0],
  270. [-1, image_height, image_width, -1]), [3]))
  271. inputs = tf.concat(axis=3, values=inputs)
  272. # Normalize channels to 1.
  273. kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
  274. kernel = tf.expand_dims(
  275. kernel / tf.reduce_sum(
  276. kernel, [3], keep_dims=True), [4])
  277. return tf.reduce_sum(kernel * inputs, [3], keep_dims=False)
  278. def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
  279. """Sample batch with specified mix of ground truth and generated data points.
  280. Args:
  281. ground_truth_x: tensor of ground-truth data points.
  282. generated_x: tensor of generated data points.
  283. batch_size: batch size
  284. num_ground_truth: number of ground-truth examples to include in batch.
  285. Returns:
  286. New batch with num_ground_truth sampled from ground_truth_x and the rest
  287. from generated_x.
  288. """
  289. idx = tf.random_shuffle(tf.range(int(batch_size)))
  290. ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
  291. generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
  292. ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
  293. generated_examps = tf.gather(generated_x, generated_idx)
  294. return tf.dynamic_stitch([ground_truth_idx, generated_idx],
  295. [ground_truth_examps, generated_examps])