# lstm_ops.py
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Convolutional LSTM implementation."""
  16. import tensorflow as tf
  17. from tensorflow.contrib.slim import add_arg_scope
  18. from tensorflow.contrib.slim import layers
  19. def init_state(inputs,
  20. state_shape,
  21. state_initializer=tf.zeros_initializer(),
  22. dtype=tf.float32):
  23. """Helper function to create an initial state given inputs.
  24. Args:
  25. inputs: input Tensor, at least 2D, the first dimension being batch_size
  26. state_shape: the shape of the state.
  27. state_initializer: Initializer(shape, dtype) for state Tensor.
  28. dtype: Optional dtype, needed when inputs is None.
  29. Returns:
  30. A tensors representing the initial state.
  31. """
  32. if inputs is not None:
  33. # Handle both the dynamic shape as well as the inferred shape.
  34. inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
  35. batch_size = tf.shape(inputs)[0]
  36. dtype = inputs.dtype
  37. else:
  38. inferred_batch_size = 0
  39. batch_size = 0
  40. initial_state = state_initializer(
  41. tf.stack([batch_size] + state_shape),
  42. dtype=dtype)
  43. initial_state.set_shape([inferred_batch_size] + state_shape)
  44. return initial_state
  45. @add_arg_scope
  46. def basic_conv_lstm_cell(inputs,
  47. state,
  48. num_channels,
  49. filter_size=5,
  50. forget_bias=1.0,
  51. scope=None,
  52. reuse=None):
  53. """Basic LSTM recurrent network cell, with 2D convolution connctions.
  54. We add forget_bias (default: 1) to the biases of the forget gate in order to
  55. reduce the scale of forgetting in the beginning of the training.
  56. It does not allow cell clipping, a projection layer, and does not
  57. use peep-hole connections: it is the basic baseline.
  58. Args:
  59. inputs: input Tensor, 4D, batch x height x width x channels.
  60. state: state Tensor, 4D, batch x height x width x channels.
  61. num_channels: the number of output channels in the layer.
  62. filter_size: the shape of the each convolution filter.
  63. forget_bias: the initial value of the forget biases.
  64. scope: Optional scope for variable_scope.
  65. reuse: whether or not the layer and the variables should be reused.
  66. Returns:
  67. a tuple of tensors representing output and the new state.
  68. """
  69. spatial_size = inputs.get_shape()[1:3]
  70. if state is None:
  71. state = init_state(inputs, list(spatial_size) + [2 * num_channels])
  72. with tf.variable_scope(scope,
  73. 'BasicConvLstmCell',
  74. [inputs, state],
  75. reuse=reuse):
  76. inputs.get_shape().assert_has_rank(4)
  77. state.get_shape().assert_has_rank(4)
  78. c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
  79. inputs_h = tf.concat(axis=3, values=[inputs, h])
  80. # Parameters of gates are concatenated into one conv for efficiency.
  81. i_j_f_o = layers.conv2d(inputs_h,
  82. 4 * num_channels, [filter_size, filter_size],
  83. stride=1,
  84. activation_fn=None,
  85. scope='Gates')
  86. # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  87. i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o)
  88. new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
  89. new_h = tf.tanh(new_c) * tf.sigmoid(o)
  90. return new_h, tf.concat(axis=3, values=[new_c, new_h])