# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Ops and utilities for neural networks.
- For now, just an LSTM layer.
- """
- import shapes
- import tensorflow as tf
- rnn = tf.load_op_library("../cc/rnn_ops.so")


def rnn_helper(inp,
               length,
               cell_type=None,
               direction="forward",
               name=None,
               *args,
               **kwargs):
  """Adds ops for a recurrent neural network layer.

  This function calls an actual implementation of a recurrent neural network
  based on `cell_type`.

  There are three modes depending on the value of `direction`:

    forward: Adds a forward RNN.
    backward: Adds a backward RNN.
    bidirectional: Adds both forward and backward RNNs and creates a
      bidirectional RNN.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
      represents the length of the corresponding sequence in `inp`.
    cell_type: Cell type of the RNN. Currently can only be "lstm".
    direction: One of "forward", "backward", "bidirectional".
    name: Name of the op.
    *args: Other arguments passed to the layer.
    **kwargs: Keyword arguments passed to the layer.

  Returns:
    A 3-D tensor of shape [`batch_size`, `max_length`, `num_nodes`], or
    [`batch_size`, `max_length`, 2 * `num_nodes`] when `direction` is
    "bidirectional" (forward and backward outputs are concatenated).
  """
  assert cell_type is not None
  rnn_func = None
  if cell_type == "lstm":
    rnn_func = lstm_layer
  assert rnn_func is not None
  assert direction in ["forward", "backward", "bidirectional"]
  with tf.variable_scope(name):
    if direction in ["forward", "bidirectional"]:
      forward = rnn_func(
          inp=inp,
          length=length,
          backward=False,
          name="forward",
          *args,
          **kwargs)
      if isinstance(forward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        forward = forward[0]
    if direction in ["backward", "bidirectional"]:
      backward = rnn_func(
          inp=inp,
          length=length,
          backward=True,
          name="backward",
          *args,
          **kwargs)
      if isinstance(backward, tuple):
        # lstm_layer returns a tuple (output, memory). We only need the first
        # element.
        backward = backward[0]
    if direction == "forward":
      out = forward
    elif direction == "backward":
      out = backward
    else:
      out = tf.concat(axis=2, values=[forward, backward])
  return out
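

# A hedged usage sketch (not part of the original module): it assumes an input
# batch of shape [batch, time, depth] and that `num_nodes` is forwarded to
# lstm_layer through **kwargs.
#
#   inp = tf.placeholder(tf.float32, shape=[8, 100, 64])
#   length = tf.constant([100] * 8, dtype=tf.int64)
#   out = rnn_helper(inp, length, cell_type="lstm", direction="bidirectional",
#                    num_nodes=128, name="rnn1")
#   # `out` has shape [8, 100, 256]: the forward and backward outputs are
#   # concatenated along the feature axis.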


@tf.RegisterShape("VariableLSTM")
def _variable_lstm_shape(op):
  """Shape function for the VariableLSTM op."""
  input_shape = op.inputs[0].get_shape().with_rank(4)
  state_shape = op.inputs[1].get_shape().with_rank(2)
  memory_shape = op.inputs[2].get_shape().with_rank(2)
  w_m_m_shape = op.inputs[3].get_shape().with_rank(3)
  batch_size = input_shape[0].merge_with(state_shape[0])
  batch_size = batch_size.merge_with(memory_shape[0])
  seq_len = input_shape[1]
  gate_num = input_shape[2].merge_with(w_m_m_shape[1])
  output_dim = input_shape[3].merge_with(state_shape[1])
  output_dim = output_dim.merge_with(memory_shape[1])
  output_dim = output_dim.merge_with(w_m_m_shape[0])
  output_dim = output_dim.merge_with(w_m_m_shape[2])
  return [[batch_size, seq_len, output_dim],
          [batch_size, seq_len, gate_num, output_dim],
          [batch_size, seq_len, output_dim]]
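

# For reference, the shapes the function above enforces (derived from the
# merges; B = batch size, T = sequence length, D = num_nodes, 4 = gate count):
#   inputs:  input [B, T, 4, D], state [B, D], memory [B, D], w_m_m [D, 4, D]
#   outputs: activations [B, T, D], gate raw activations [B, T, 4, D],
#            memory [B, T, D]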


@tf.RegisterGradient("VariableLSTM")
def _variable_lstm_grad(op, act_grad, gate_grad, mem_grad):
  """Gradient function for the VariableLSTM op."""
  initial_state = op.inputs[1]
  initial_memory = op.inputs[2]
  w_m_m = op.inputs[3]
  act = op.outputs[0]
  gate_raw_act = op.outputs[1]
  memory = op.outputs[2]
  return rnn.variable_lstm_grad(initial_state, initial_memory, w_m_m, act,
                                gate_raw_act, memory, act_grad, gate_grad,
                                mem_grad)


def lstm_layer(inp,
               length=None,
               state=None,
               memory=None,
               num_nodes=None,
               backward=False,
               clip=50.0,
               reg_func=tf.nn.l2_loss,
               weight_reg=False,
               weight_collection="LSTMWeights",
               bias_reg=False,
               stddev=None,
               seed=None,
               decode=False,
               use_native_weights=False,
               name=None):
  """Adds ops for an LSTM layer.

  This adds ops for the following operations:

    input => (forward-LSTM|backward-LSTM) => output

  The direction of the LSTM is determined by `backward`. If it is False, the
  forward LSTM is used; otherwise, the backward one.

  Args:
    inp: A 3-D tensor of shape [`batch_size`, `max_length`, `feature_dim`].
    length: A 1-D tensor of shape [`batch_size`] and type int64. Each element
      represents the length of the corresponding sequence in `inp`.
    state: If specified, used as the initial state.
    memory: If specified, used as the initial memory.
    num_nodes: The number of LSTM cells.
    backward: If true, reverses `inp` before adding the ops. The output is
      also reversed so that it is aligned with `inp`.
    clip: Value used to clip the cell values.
    reg_func: Function used for weight regularization, such as
      `tf.nn.l2_loss`.
    weight_reg: If true, regularizes the filter weights with `reg_func`.
    weight_collection: Collection to add the regularization losses to.
    bias_reg: If true, regularizes the bias vector with `reg_func`.
    stddev: Standard deviation used to initialize the variables.
    seed: Seed used to initialize the variables.
    decode: If true, does not add ops which are not used for inference.
    use_native_weights: If true, uses weights in the same format as the native
      implementations.
    name: Name of the op.

  Returns:
    A tuple (out, mem): `out` is a 3-D tensor of shape
    [`batch_size`, `max_length`, `num_nodes`] with the LSTM outputs, and `mem`
    is a tensor of the same shape with the corresponding cell memory.
  """
  with tf.variable_scope(name):
    if backward:
      if length is None:
        inp = tf.reverse(inp, [1])
      else:
        inp = tf.reverse_sequence(inp, length, 1, 0)
    num_prev = inp.get_shape()[2]
    if stddev:
      initializer = tf.truncated_normal_initializer(stddev=stddev, seed=seed)
    else:
      initializer = tf.uniform_unit_scaling_initializer(seed=seed)
    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        w = tf.get_variable(
            "W_0",
            shape=[num_prev + num_nodes, 4 * num_nodes],
            initializer=initializer,
            dtype=tf.float32)
      w_i_m = tf.slice(w, [0, 0], [num_prev, 4 * num_nodes], name="w_i_m")
      w_m_m = tf.reshape(
          tf.slice(w, [num_prev, 0], [num_nodes, 4 * num_nodes]),
          [num_nodes, 4, num_nodes],
          name="w_m_m")
    else:
      w_i_m = tf.get_variable("w_i_m", [num_prev, 4 * num_nodes],
                              initializer=initializer)
      w_m_m = tf.get_variable("w_m_m", [num_nodes, 4, num_nodes],
                              initializer=initializer)
    if not decode and weight_reg:
      tf.add_to_collection(weight_collection,
                           reg_func(w_i_m, name="w_i_m_reg"))
      tf.add_to_collection(weight_collection,
                           reg_func(w_m_m, name="w_m_m_reg"))
    batch_size = shapes.tensor_dim(inp, dim=0)
    num_frames = shapes.tensor_dim(inp, dim=1)
    prev = tf.reshape(inp, tf.stack([batch_size * num_frames, num_prev]))
    if use_native_weights:
      with tf.variable_scope("LSTMCell"):
        b = tf.get_variable(
            "B",
            shape=[4 * num_nodes],
            initializer=tf.zeros_initializer(),
            dtype=tf.float32)
      biases = tf.identity(b, name="biases")
    else:
      biases = tf.get_variable(
          "biases", [4 * num_nodes], initializer=tf.constant_initializer(0.0))
    if not decode and bias_reg:
      tf.add_to_collection(weight_collection,
                           reg_func(biases, name="biases_reg"))
    prev = tf.nn.xw_plus_b(prev, w_i_m, biases)
    prev = tf.reshape(prev, tf.stack([batch_size, num_frames, 4, num_nodes]))
    if state is None:
      state = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)
    if memory is None:
      memory = tf.fill(tf.stack([batch_size, num_nodes]), 0.0)
    out, _, mem = rnn.variable_lstm(prev, state, memory, w_m_m, clip=clip)
    if backward:
      if length is None:
        out = tf.reverse(out, [1])
      else:
        out = tf.reverse_sequence(out, length, 1, 0)
  return out, mem
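

# A hedged usage sketch (not part of the original module): it assumes the
# rnn_ops.so custom-op library loaded at the top of this file is available.
#
#   inp = tf.placeholder(tf.float32, shape=[8, 100, 64])
#   length = tf.constant([100] * 8, dtype=tf.int64)
#   out, mem = lstm_layer(inp, length=length, num_nodes=128, backward=True,
#                         name="lstm_bw")
#   # `out` and `mem` both have shape [8, 100, 128]; `out` is re-reversed so
#   # that it is aligned with `inp`.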