/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// OpKernels of LSTM Neural Networks:
//
//   LSTM: VariableLSTMOp (VariableLSTMGradOp)
//
// where the ops in parentheses compute the gradients for the corresponding
// forward ops.

#define EIGEN_USE_THREADS

#include <vector>

#ifdef GOOGLE_INCLUDES
#include "third_party/eigen3/Eigen/Core"
#include "third_party/tensorflow/core/framework/op.h"
#include "third_party/tensorflow/core/framework/op_kernel.h"
#include "third_party/tensorflow/core/framework/tensor.h"
#else
#include "Eigen/Core"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#endif  // GOOGLE_INCLUDES

namespace tensorflow {

using Eigen::array;
using Eigen::DenseIndex;
using IndexPair = Eigen::IndexPair<int>;

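// Returns an InvalidArgument status when the two dimensions differ; e.g.
// AreDimsEqual(3, 4, "State batch") yields the error message
// "State batch: 3 vs. 4".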
Status AreDimsEqual(int dim1, int dim2, const string& message) {
  if (dim1 != dim2) {
    return errors::InvalidArgument(message, ": ", dim1, " vs. ", dim2);
  }
  return Status::OK();
}

// ------------------------------- VariableLSTMOp -----------------------------

// Kernel to compute the forward propagation of a Long Short-Term Memory
// network. See the doc of the op below for more detail.
class VariableLSTMOp : public OpKernel {
 public:
  explicit VariableLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("clip", &clip_));
    OP_REQUIRES(ctx, clip_ >= 0.0,
                errors::InvalidArgument(
                    "clip_ needs to be greater than or equal to 0"));
  }

  void Compute(OpKernelContext* ctx) override {
    // Inputs.
    const auto input = ctx->input(0).tensor<float, 4>();
    const auto initial_state = ctx->input(1).tensor<float, 2>();
    const auto initial_memory = ctx->input(2).tensor<float, 2>();
    const auto w_m_m = ctx->input(3).tensor<float, 3>();

    const int batch_size = input.dimension(0);
    const int seq_len = input.dimension(1);
    const int output_dim = input.dimension(3);

    // Sanity checks.
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, input.dimension(2), "Input num"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_state.dimension(0),
                                     "State batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_state.dimension(1),
                                     "State dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_memory.dimension(0),
                                     "Memory batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_memory.dimension(1),
                                     "Memory dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(0), "Weight dim 0"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, w_m_m.dimension(1), "Weight dim 1"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(2), "Weight dim 2"));

    // Outputs.
    Tensor* act_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            0, {batch_size, seq_len, output_dim}, &act_tensor));
    auto act = act_tensor->tensor<float, 3>();
    act.setZero();

    Tensor* gate_raw_act_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            1, {batch_size, seq_len, 4, output_dim},
                            &gate_raw_act_tensor));
    auto gate_raw_act = gate_raw_act_tensor->tensor<float, 4>();
    gate_raw_act.setZero();

    Tensor* memory_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, {batch_size, seq_len, output_dim},
                                        &memory_tensor));
    auto memory = memory_tensor->tensor<float, 3>();
    memory.setZero();

    // Const and scratch tensors.
    Tensor ones_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &ones_tensor));
    auto ones = ones_tensor.tensor<float, 2>();
    ones.setConstant(1.0);

    Tensor state_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &state_tensor));
    auto state = state_tensor.tensor<float, 2>();
    state = initial_state;

    Tensor scratch_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {batch_size, 4, output_dim},
                                      &scratch_tensor));
    auto scratch = scratch_tensor.tensor<float, 3>();
    scratch.setZero();

    // Uses the most efficient order for the contraction depending on the
    // batch size.
    // This is the code shared by both cases. Implicit capture in lambda
    // functions is generally discouraged, but what is captured and mutated
    // here should be clear from the context.
    auto Forward = [&](int i) {
      // Each pre-activation value is stored in the following order (see the
      // comment of the op for the meaning):
      //
      //   i: 0
      //   j: 1
      //   f: 2
      //   o: 3

      // Adds one to the pre-activation values of the forget gate. This is a
      // heuristic to make the training easier.
      scratch.chip(2, 1) += ones;
      gate_raw_act.chip(i, 1) = scratch;

      // c_t = f_t * c_{t-1} + i_t * j_t
      if (i == 0) {
        state = initial_memory * scratch.chip(2, 1).sigmoid();
      } else {
        state = memory.chip(i - 1, 1) * scratch.chip(2, 1).sigmoid();
      }
      state += scratch.chip(0, 1).sigmoid() * scratch.chip(1, 1).tanh();
      if (clip_ > 0.0) {
        // Clips the values if required.
        state = state.cwiseMax(-clip_).cwiseMin(clip_);
      }
      memory.chip(i, 1) = state;

      // h_t = o_t * tanh(c_t)
      state = scratch.chip(3, 1).sigmoid() * state.tanh();
      act.chip(i, 1) = state;
    };
    if (batch_size == 1) {
      // Reshapes the weight tensor so that the contraction below is an
      // ordinary matrix multiplication, which is more efficient.
      auto w_m_m_r =
          w_m_m.reshape(array<DenseIndex, 2>{output_dim, 4 * output_dim});
      // Dimensions for the contraction.
      const array<IndexPair, 1> m_m_dim = {IndexPair(1, 0)};
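      // IndexPair(1, 0) contracts dimension 1 of the left operand with
      // dimension 0 of the right operand, i.e. a standard matrix product:
      // [batch_size, output_dim] x [output_dim, 4 * output_dim]
      //   -> [batch_size, 4 * output_dim].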
      for (int i = 0; i < seq_len; ++i) {
        // Computes the pre-activation value of the input and each gate.
        scratch =
            input.chip(i, 1) +
            state.contract(w_m_m_r, m_m_dim)
                .reshape(array<DenseIndex, 3>{batch_size, 4, output_dim});
        Forward(i);
      }
    } else {
      // Shuffles the dimensions of the weight tensor to be efficient when
      // used on the left-hand side. Allocates memory for the shuffled tensor
      // for efficiency.
      Tensor w_m_m_s_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                              DT_FLOAT, {output_dim * 4, output_dim},
                              &w_m_m_s_tensor));
      auto w_m_m_s = w_m_m_s_tensor.tensor<float, 2>();
      w_m_m_s = w_m_m.shuffle(array<int, 3>{2, 1, 0})
                    .reshape(array<DenseIndex, 2>{output_dim * 4, output_dim});
      // Dimensions for the contraction.
      const array<IndexPair, 1> m_m_dim = {IndexPair(1, 1)};
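      // IndexPair(1, 1) contracts dimension 1 of w_m_m_s with dimension 1 of
      // state: [output_dim * 4, output_dim] x [batch_size, output_dim]
      //   -> [output_dim * 4, batch_size], which the reshape and shuffle
      // below turn back into [batch_size, 4, output_dim].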
      for (int i = 0; i < seq_len; ++i) {
        // Computes the pre-activation value of the input and each gate.
        scratch =
            input.chip(i, 1) +
            w_m_m_s.contract(state, m_m_dim)
                .reshape(array<DenseIndex, 3>{output_dim, 4, batch_size})
                .shuffle(array<int, 3>{2, 1, 0});
        Forward(i);
      }
    }
  }

 private:
  // Threshold to clip the values of memory cells.
  float clip_ = 0;
};

REGISTER_KERNEL_BUILDER(Name("VariableLSTM").Device(DEVICE_CPU),
                        VariableLSTMOp);

REGISTER_OP("VariableLSTM")
    .Attr("clip: float = 0.0")
    .Input("input: float32")
    .Input("initial_state: float32")
    .Input("initial_memory: float32")
    .Input("w_m_m: float32")
    .Output("activation: float32")
    .Output("gate_raw_act: float32")
    .Output("memory: float32")
    .Doc(R"doc(
Computes the forward propagation of a Long Short-Term Memory network.

It computes the following equations recursively for `0 < t <= T`:

  i_t = sigmoid(a_{i,t})
  j_t = tanh(a_{j,t})
  f_t = sigmoid(a_{f,t} + 1.0)
  o_t = sigmoid(a_{o,t})
  c_t = f_t * c_{t-1} + i_t * j_t
  c'_t = min(max(c_t, -clip), clip) if clip > 0 else c_t
  h_t = o_t * tanh(c'_t)

where

  a_{l,t} = w_{l,m,m} * h_{t-1} + x'_{l,t}

and

  x'_{l,t} = w_{l,m,i} * x_{t}.

`input` corresponds to the concatenation of `X'_i`, `X'_j`, `X'_f`, and `X'_o`
where `X'_l = (x'_{l,1}, x'_{l,2}, ..., x'_{l,T})`, `initial_state` corresponds
to `h_{0}`, `initial_memory` corresponds to `c_{0}` and `w_m_m` corresponds to
`w_{l,m,m}`. `X'_l` (the transformed input) is computed outside of the op in
advance, so `w_{l,m,i}` is not passed in to the op.

`activation` corresponds to `H = (h_1, h_2, ..., h_T)`, `gate_raw_act`
corresponds to the concatenation of `A_i`, `A_j`, `A_f` and `A_o`, and `memory`
corresponds to `C = (c'_1, c'_2, ..., c'_T)`.

All sequences in the batch are propagated to the end and are assumed to have
the same length.

input: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
initial_state: 2-D with shape `[batch_size, num_nodes]`
initial_memory: 2-D with shape `[batch_size, num_nodes]`
w_m_m: 3-D with shape `[num_nodes, 4, num_nodes]`
activation: 3-D with shape `[batch_size, seq_len, num_nodes]`
gate_raw_act: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
memory: 3-D with shape `[batch_size, seq_len, num_nodes]`
)doc");

// ----------------------------- VariableLSTMGradOp ---------------------------

// Kernel to compute the gradient of VariableLSTMOp.
class VariableLSTMGradOp : public OpKernel {
 public:
  explicit VariableLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Inputs.
    const auto initial_state = ctx->input(0).tensor<float, 2>();
    const auto initial_memory = ctx->input(1).tensor<float, 2>();
    const auto w_m_m = ctx->input(2).tensor<float, 3>();
    const auto act = ctx->input(3).tensor<float, 3>();
    const auto gate_raw_act = ctx->input(4).tensor<float, 4>();
    const auto memory = ctx->input(5).tensor<float, 3>();
    const auto act_grad = ctx->input(6).tensor<float, 3>();
    const auto gate_raw_act_grad = ctx->input(7).tensor<float, 4>();
    const auto memory_grad = ctx->input(8).tensor<float, 3>();

    const int batch_size = act.dimension(0);
    const int seq_len = act.dimension(1);
    const int output_dim = act.dimension(2);

    // Sanity checks.
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_state.dimension(0),
                                     "State batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_state.dimension(1),
                                     "State dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_memory.dimension(0),
                                     "Memory batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_memory.dimension(1),
                                     "Memory dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(0), "Weight dim 0"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, w_m_m.dimension(1), "Weight dim 1"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(2), "Weight dim 2"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, gate_raw_act.dimension(0),
                                     "Gate raw activation batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, gate_raw_act.dimension(1),
                                     "Gate raw activation len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, gate_raw_act.dimension(2),
                                     "Gate raw activation num"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, gate_raw_act.dimension(3),
                                     "Gate raw activation dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(batch_size, memory.dimension(0), "Memory batch"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(seq_len, memory.dimension(1), "Memory len"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(output_dim, memory.dimension(2), "Memory dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, act_grad.dimension(0),
                                     "Activation gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, act_grad.dimension(1),
                                     "Activation gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, act_grad.dimension(2),
                                     "Activation gradient dim"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(batch_size, gate_raw_act_grad.dimension(0),
                                "Gate raw activation gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, gate_raw_act_grad.dimension(1),
                                     "Gate raw activation gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, gate_raw_act_grad.dimension(2),
                                     "Gate raw activation gradient num"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(output_dim, gate_raw_act_grad.dimension(3),
                                "Gate raw activation gradient dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, memory_grad.dimension(0),
                                     "Memory gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, memory_grad.dimension(1),
                                     "Memory gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, memory_grad.dimension(2),
                                     "Memory gradient dim"));

    // Outputs.
    std::vector<Tensor*> collections(4, nullptr);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            0, {batch_size, seq_len, 4, output_dim},
                            &collections[0]));
    auto input_grad = collections[0]->tensor<float, 4>();
    input_grad.setZero();
    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {batch_size, output_dim},
                                             &collections[1]));
    auto init_state_grad = collections[1]->tensor<float, 2>();
    init_state_grad.setZero();
    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {batch_size, output_dim},
                                             &collections[2]));
    auto init_memory_grad = collections[2]->tensor<float, 2>();
    init_memory_grad.setZero();
    OP_REQUIRES_OK(ctx, ctx->allocate_output(3, {output_dim, 4, output_dim},
                                             &collections[3]));
    auto w_m_m_grad = collections[3]->tensor<float, 3>();
    w_m_m_grad.setZero();

    // Const and scratch tensors.
    Tensor ones_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &ones_tensor));
    auto ones = ones_tensor.tensor<float, 2>();
    ones.setConstant(1.0);

    Tensor scratch_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {batch_size, 4, output_dim},
                                      &scratch_tensor));
    auto scratch = scratch_tensor.tensor<float, 3>();
    scratch.setZero();

    Tensor tmp1_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &tmp1_tensor));
    auto tmp1 = tmp1_tensor.tensor<float, 2>();
    tmp1.setZero();

    Tensor tmp2_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &tmp2_tensor));
    auto tmp2 = tmp2_tensor.tensor<float, 2>();
    tmp2.setZero();
    // Uses the most efficient order for the contraction depending on the
    // batch size.

    // Shuffles the dimensions of the weight tensor to be efficient when used
    // on the left-hand side. Allocates memory for the shuffled tensor up
    // front for efficiency.
    Tensor w_m_m_s_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {4, output_dim, output_dim},
                                      &w_m_m_s_tensor));
    auto w_m_m_s = w_m_m_s_tensor.tensor<float, 3>();
    if (batch_size == 1) {
      // Computes the shuffled tensor only when it is actually used; the
      // shuffle maps w_m_m from [output_dim, 4, output_dim] to
      // [4, output_dim, output_dim].
      w_m_m_s = w_m_m.shuffle(array<int, 3>{1, 2, 0});
    }
    // Dimensions for the contraction with the weight tensor.
    const array<IndexPair, 1> m_m_dim =
        batch_size == 1 ? array<IndexPair, 1>{IndexPair(1, 0)}
                        : array<IndexPair, 1>{IndexPair(1, 1)};
    // Dimensions for the contraction of the batch dimensions.
    const array<IndexPair, 1> b_b_dim = {IndexPair(0, 0)};
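    // IndexPair(0, 0) contracts the batch dimensions of its two operands,
    // e.g. [batch_size, output_dim] x [batch_size, 4 * output_dim]
    //   -> [output_dim, 4 * output_dim], summing the per-example outer
    // products to accumulate the weight gradient over the batch.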
    for (int i = seq_len - 1; i >= 0; --i) {
      if (i == seq_len - 1) {
        init_state_grad = act_grad.chip(i, 1);
      } else {
        w_m_m_grad +=
            act.chip(i, 1)
                .contract(scratch.reshape(array<DenseIndex, 2>{
                              batch_size, 4 * output_dim}),
                          b_b_dim)
                .reshape(array<DenseIndex, 3>{output_dim, 4, output_dim});
        if (batch_size == 1) {
          init_state_grad.device(ctx->eigen_cpu_device()) =
              scratch.chip(0, 1).contract(w_m_m_s.chip(0, 0), m_m_dim) +
              scratch.chip(1, 1).contract(w_m_m_s.chip(1, 0), m_m_dim) +
              scratch.chip(2, 1).contract(w_m_m_s.chip(2, 0), m_m_dim) +
              scratch.chip(3, 1).contract(w_m_m_s.chip(3, 0), m_m_dim);
        } else {
          init_state_grad.device(ctx->eigen_cpu_device()) =
              (w_m_m.chip(0, 1).contract(scratch.chip(0, 1), m_m_dim) +
               w_m_m.chip(1, 1).contract(scratch.chip(1, 1), m_m_dim) +
               w_m_m.chip(2, 1).contract(scratch.chip(2, 1), m_m_dim) +
               w_m_m.chip(3, 1).contract(scratch.chip(3, 1), m_m_dim))
                  .shuffle(array<int, 2>{1, 0});
        }
        init_state_grad += act_grad.chip(i, 1);
      }
      auto gate_raw_act_t = gate_raw_act.chip(i, 1);
      auto gate_raw_act_grad_t = gate_raw_act_grad.chip(i, 1);

      // Output gate.
      tmp1 = memory.chip(i, 1);
      tmp1 = tmp1.tanh();                          // tanh(c'_t)
      tmp2 = gate_raw_act_t.chip(3, 1).sigmoid();  // o_t
      scratch.chip(3, 1) = init_state_grad * tmp1 * tmp2 * (ones - tmp2) +
                           gate_raw_act_grad_t.chip(3, 1);
      init_memory_grad += init_state_grad * tmp2 * (ones - tmp1.square()) +
                          memory_grad.chip(i, 1);
      // Input gate.
      tmp1 = gate_raw_act_t.chip(0, 1).sigmoid();  // i_t
      tmp2 = gate_raw_act_t.chip(1, 1);
      tmp2 = tmp2.tanh();  // j_t
      scratch.chip(0, 1) = init_memory_grad * tmp2 * tmp1 * (ones - tmp1) +
                           gate_raw_act_grad_t.chip(0, 1);
      // Input.
      scratch.chip(1, 1) = init_memory_grad * tmp1 * (ones - tmp2.square()) +
                           gate_raw_act_grad_t.chip(1, 1);
      // Forget gate.
      tmp1 = gate_raw_act_t.chip(2, 1).sigmoid();  // f_t
      if (i == 0) {
        scratch.chip(2, 1) =
            init_memory_grad * initial_memory * tmp1 * (ones - tmp1) +
            gate_raw_act_grad_t.chip(2, 1);
      } else {
        scratch.chip(2, 1) =
            init_memory_grad * memory.chip(i - 1, 1) * tmp1 * (ones - tmp1) +
            gate_raw_act_grad_t.chip(2, 1);
      }
      // Memory.
      init_memory_grad *= tmp1;
      input_grad.chip(i, 1) = scratch;
    }
    w_m_m_grad += initial_state
                      .contract(scratch.reshape(array<DenseIndex, 2>{
                                    batch_size, 4 * output_dim}),
                                b_b_dim)
                      .reshape(array<DenseIndex, 3>{output_dim, 4, output_dim});
    if (batch_size == 1) {
      init_state_grad.device(ctx->eigen_cpu_device()) =
          (scratch.chip(0, 1).contract(w_m_m_s.chip(0, 0), m_m_dim) +
           scratch.chip(1, 1).contract(w_m_m_s.chip(1, 0), m_m_dim) +
           scratch.chip(2, 1).contract(w_m_m_s.chip(2, 0), m_m_dim) +
           scratch.chip(3, 1).contract(w_m_m_s.chip(3, 0), m_m_dim));
    } else {
      init_state_grad.device(ctx->eigen_cpu_device()) =
          (w_m_m.chip(0, 1).contract(scratch.chip(0, 1), m_m_dim) +
           w_m_m.chip(1, 1).contract(scratch.chip(1, 1), m_m_dim) +
           w_m_m.chip(2, 1).contract(scratch.chip(2, 1), m_m_dim) +
           w_m_m.chip(3, 1).contract(scratch.chip(3, 1), m_m_dim))
              .shuffle(array<int, 2>{1, 0});
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("VariableLSTMGrad").Device(DEVICE_CPU),
                        VariableLSTMGradOp);

REGISTER_OP("VariableLSTMGrad")
    .Input("initial_state: float32")
    .Input("initial_memory: float32")
    .Input("w_m_m: float32")
    .Input("activation: float32")
    .Input("gate_raw_act: float32")
    .Input("memory: float32")
    .Input("act_grad: float32")
    .Input("gate_raw_act_grad: float32")
    .Input("memory_grad: float32")
    .Output("input_grad: float32")
    .Output("initial_state_grad: float32")
    .Output("initial_memory_grad: float32")
    .Output("w_m_m_grad: float32")
    .Doc(R"doc(
- .Doc(R"doc(
- Computes the gradient for VariableLSTM.
- This is to be used conjunction with VariableLSTM. It ignores the clipping used
- in the forward pass.
- initial_state: 2-D with shape `[batch_size, num_nodes]`
- initial_memory: 2-D with shape `[batch_size, num_nodes]`
- w_m_m: 3-D with shape `[num_nodes, 4, num_nodes]`
- activation: 3-D with shape `[batch_size, seq_len, num_nodes]`
- gate_raw_act: 3-D with shape `[batch_size, seq_len, 4, num_nodes]`
- memory: 3-D with shape `[batch_size, seq_len, num_nodes]`
- act_grad: 3-D with shape `[batch_size, seq_len, num_nodes]`
- gate_raw_act_grad: 3-D with shape `[batch_size, seq_len, 4, num_nodes]`
- memory_grad: 3-D with shape `[batch_size, seq_len, num_nodes]`
- input_grad: 3-D with shape `[batch_size, seq_len, num_nodes]`
- initial_state_grad: 2-D with shape `[batch_size, num_nodes]`
- initial_memory_grad: 2-D with shape `[batch_size, num_nodes]`
- w_m_m_grad: 3-D with shape `[num_nodes, 4, num_nodes]`
- )doc");

}  // namespace tensorflow