nn_ops.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Ops and utilities for neural networks.
  16. For now, just an LSTM layer.
  17. """
  18. import shapes
  19. import tensorflow as tf
  20. rnn = tf.load_op_library("../cc/rnn_ops.so")


def rnn_helper(inp,
               length,
               cell_type=None,
               direction="forward",
               name=None,
               *args,
               **kwargs):
  """Adds ops for a recurrent neural network layer.

  This function calls an actual implementation of a recurrent neural network
  based on `cell_type`.

  There are three modes depending on the value of `direction`:

    forward: Adds a forward RNN.
    backward: Adds a backward RNN.
    bidirectional: Adds both forward and backward RNNs and creates a
      bidirectional RNN.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
      represents the length of the corresponding sequence in `inp`.
    cell_type: Cell type of the RNN. Currently can only be "lstm".
    direction: One of "forward", "backward", "bidirectional".
    name: Name of the op.
    *args: Other arguments passed to the layer.
    **kwargs: Keyword arguments passed to the layer.

  Returns:
    A 3-D tensor of shape [`batch_size`, `max_length`, `num_nodes`]
    (or [`batch_size`, `max_length`, 2 * `num_nodes`] for "bidirectional",
    since the forward and backward outputs are concatenated).
  """
  assert cell_type is not None
  rnn_func = None
  if cell_type == "lstm":
    rnn_func = lstm_layer
  assert rnn_func is not None

  assert direction in ["forward", "backward", "bidirectional"]

  with tf.variable_scope(name):
    if direction in ["forward", "bidirectional"]:
      forward = rnn_func(
          inp=inp,
          length=length,
          backward=False,
          name="forward",
          *args,
          **kwargs)
      if isinstance(forward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        forward = forward[0]
    if direction in ["backward", "bidirectional"]:
      backward = rnn_func(
          inp=inp,
          length=length,
          backward=True,
          name="backward",
          *args,
          **kwargs)
      if isinstance(backward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        backward = backward[0]
    if direction == "forward":
      out = forward
    elif direction == "backward":
      out = backward
    else:
      # Bidirectional: concatenate the forward and backward outputs along the
      # feature axis.
      out = tf.concat(axis=2, values=[forward, backward])
  return out
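
# A minimal usage sketch (an illustrative assumption, not part of the original
# file). It presumes the custom rnn_ops.so kernel is available; the feature
# dimension (96) and node count (64) below are arbitrary.
#
#   features = tf.placeholder(tf.float32, [None, None, 96])
#   lengths = tf.placeholder(tf.int64, [None])
#   out = rnn_helper(features, lengths, cell_type="lstm",
#                    direction="bidirectional", num_nodes=64, name="rnn1")
#   # `out` has shape [batch_size, max_length, 128]: the forward and backward
#   # outputs (64 nodes each) are concatenated along axis 2.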


@tf.RegisterShape("VariableLSTM")
def _variable_lstm_shape(op):
  """Shape function for the VariableLSTM op."""
  input_shape = op.inputs[0].get_shape().with_rank(4)
  state_shape = op.inputs[1].get_shape().with_rank(2)
  memory_shape = op.inputs[2].get_shape().with_rank(2)
  w_m_m_shape = op.inputs[3].get_shape().with_rank(3)
  batch_size = input_shape[0].merge_with(state_shape[0])
  batch_size = batch_size.merge_with(memory_shape[0])
  seq_len = input_shape[1]
  gate_num = input_shape[2].merge_with(w_m_m_shape[1])
  output_dim = input_shape[3].merge_with(state_shape[1])
  output_dim = output_dim.merge_with(memory_shape[1])
  output_dim = output_dim.merge_with(w_m_m_shape[0])
  output_dim = output_dim.merge_with(w_m_m_shape[2])
  return [[batch_size, seq_len, output_dim],
          [batch_size, seq_len, gate_num, output_dim],
          [batch_size, seq_len, output_dim]]


@tf.RegisterGradient("VariableLSTM")
def _variable_lstm_grad(op, act_grad, gate_grad, mem_grad):
  """Gradient function for the VariableLSTM op."""
  initial_state = op.inputs[1]
  initial_memory = op.inputs[2]
  w_m_m = op.inputs[3]
  act = op.outputs[0]
  gate_raw_act = op.outputs[1]
  memory = op.outputs[2]
  return rnn.variable_lstm_grad(initial_state, initial_memory, w_m_m, act,
                                gate_raw_act, memory, act_grad, gate_grad,
                                mem_grad)
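
# Illustrative note (an assumption, not part of the original file): with the
# gradient registered above, ordinary TensorFlow differentiation works through
# the custom op, e.g. for an `out` built by lstm_layer:
#
#   loss = tf.reduce_sum(out)
#   grads = tf.gradients(loss, tf.trainable_variables())
#
# tf.gradients dispatches to rnn.variable_lstm_grad for every VariableLSTM op
# on the path from `loss` to the variables.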


def lstm_layer(inp,
               length=None,
               state=None,
               memory=None,
               num_nodes=None,
               backward=False,
               clip=50.0,
               reg_func=tf.nn.l2_loss,
               weight_reg=False,
               weight_collection="LSTMWeights",
               bias_reg=False,
               stddev=None,
               seed=None,
               decode=False,
               use_native_weights=False,
               name=None):
  """Adds ops for an LSTM layer.

  This adds ops for the following operations:

    input => (forward-LSTM|backward-LSTM) => output

  The direction of the LSTM is determined by `backward`. If it is false, a
  forward LSTM is used; otherwise, a backward one.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
      represents the length of the corresponding sequence in `inp`.
    state: If specified, used as the initial state.
    memory: If specified, used as the initial memory.
    num_nodes: The number of LSTM cells.
    backward: If true, reverses `inp` before adding the ops. The output is
      also reversed so that its direction is the same as `inp`.
    clip: Value used to clip the cell values.
    reg_func: Function used for the weight regularization, such as
      `tf.nn.l2_loss`.
    weight_reg: If true, regularizes the LSTM weight matrices with `reg_func`.
    weight_collection: Collection to add the weights to for regularization.
    bias_reg: If true, regularizes the bias vector with `reg_func`.
    stddev: Standard deviation used to initialize the variables.
    seed: Seed used to initialize the variables.
    decode: If true, does not add ops which are used only for training (such
      as the regularization losses).
    use_native_weights: If true, uses weights in the same format as the native
      implementations.
    name: Name of the op.

  Returns:
    A tuple `(out, mem)`: `out` is a 3-D tensor of shape
    [`batch_size`, `max_length`, `num_nodes`] containing the LSTM outputs,
    and `mem` is the corresponding cell memory.
  """
  with tf.variable_scope(name):
    if backward:
      # Reverse the sequences so that the kernel can run a forward pass over
      # the reversed input.
      if length is None:
        inp = tf.reverse(inp, [1])
      else:
        inp = tf.reverse_sequence(inp, length, 1, 0)

    num_prev = inp.get_shape()[2]
    if stddev:
      initializer = tf.truncated_normal_initializer(stddev=stddev, seed=seed)
    else:
      initializer = tf.uniform_unit_scaling_initializer(seed=seed)

    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        w = tf.get_variable(
            "W_0",
            shape=[num_prev + num_nodes, 4 * num_nodes],
            initializer=initializer,
            dtype=tf.float32)
      # Split the native weight matrix into input-to-hidden (w_i_m) and
      # hidden-to-hidden (w_m_m) parts.
      w_i_m = tf.slice(w, [0, 0], [num_prev, 4 * num_nodes], name="w_i_m")
      w_m_m = tf.reshape(
          tf.slice(w, [num_prev, 0], [num_nodes, 4 * num_nodes]),
          [num_nodes, 4, num_nodes],
          name="w_m_m")
    else:
      w_i_m = tf.get_variable("w_i_m", [num_prev, 4 * num_nodes],
                              initializer=initializer)
      w_m_m = tf.get_variable("w_m_m", [num_nodes, 4, num_nodes],
                              initializer=initializer)
    if not decode and weight_reg:
      tf.add_to_collection(weight_collection, reg_func(w_i_m, name="w_i_m_reg"))
      tf.add_to_collection(weight_collection, reg_func(w_m_m, name="w_m_m_reg"))

    # Precompute the input-to-hidden contribution for all gates and frames in
    # a single matmul; the recurrent part is handled by the custom kernel.
    batch_size = shapes.tensor_dim(inp, dim=0)
    num_frames = shapes.tensor_dim(inp, dim=1)
    prev = tf.reshape(inp, tf.stack([batch_size * num_frames, num_prev]))

    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        b = tf.get_variable(
            "B",
            shape=[4 * num_nodes],
            initializer=tf.zeros_initializer(),
            dtype=tf.float32)
      biases = tf.identity(b, name="biases")
    else:
      biases = tf.get_variable(
          "biases", [4 * num_nodes], initializer=tf.constant_initializer(0.0))
    if not decode and bias_reg:
      tf.add_to_collection(
          weight_collection, reg_func(biases, name="biases_reg"))
    prev = tf.nn.xw_plus_b(prev, w_i_m, biases)
    prev = tf.reshape(prev, tf.stack([batch_size, num_frames, 4, num_nodes]))

    if state is None:
      state = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)
    if memory is None:
      memory = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)
    out, _, mem = rnn.variable_lstm(prev, state, memory, w_m_m, clip=clip)

    if backward:
      # Undo the reversal so the output matches the original time order.
      if length is None:
        out = tf.reverse(out, [1])
      else:
        out = tf.reverse_sequence(out, length, 1, 0)
  return out, mem
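
# A minimal usage sketch for lstm_layer (an illustrative assumption, not part
# of the original file); the placeholder shapes and the 32-node size are
# arbitrary.
#
#   inp = tf.placeholder(tf.float32, [None, None, 48])
#   length = tf.placeholder(tf.int64, [None])
#   out, mem = lstm_layer(inp, length=length, num_nodes=32, backward=True,
#                         name="lstm_backward")
#   # `out` is [batch_size, max_length, 32] in the original time order (the
#   # internal reversal is undone before returning); `mem` is the cell memory.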