/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// OpKernels for LSTM Neural Networks:
//
//   LSTM: VariableLSTMOp (VariableLSTMGradOp)
//
// where the op in parentheses computes the gradients for the corresponding op.

#define EIGEN_USE_THREADS

#include <vector>

#ifdef GOOGLE_INCLUDES
#include "third_party/eigen3/Eigen/Core"
#include "third_party/tensorflow/core/framework/op.h"
#include "third_party/tensorflow/core/framework/op_kernel.h"
#include "third_party/tensorflow/core/framework/tensor.h"
#else
#include "Eigen/Core"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#endif  // GOOGLE_INCLUDES

namespace tensorflow {

using Eigen::array;
using Eigen::DenseIndex;

using IndexPair = Eigen::IndexPair<int>;

Status AreDimsEqual(int dim1, int dim2, const string& message) {
  if (dim1 != dim2) {
    return errors::InvalidArgument(message, ": ", dim1, " vs. ", dim2);
  }
  return Status::OK();
}

// ------------------------------- VariableLSTMOp -----------------------------

// Kernel to compute the forward propagation of a Long Short-Term Memory
// network. See the doc of the op below for more detail.
class VariableLSTMOp : public OpKernel {
 public:
  explicit VariableLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("clip", &clip_));
    OP_REQUIRES(
        ctx, clip_ >= 0.0,
        errors::InvalidArgument("clip needs to be greater than or equal to 0"));
  }

  void Compute(OpKernelContext* ctx) override {
    // Inputs.
    const auto input = ctx->input(0).tensor<float, 4>();
    const auto initial_state = ctx->input(1).tensor<float, 2>();
    const auto initial_memory = ctx->input(2).tensor<float, 2>();
    const auto w_m_m = ctx->input(3).tensor<float, 3>();

    const int batch_size = input.dimension(0);
    const int seq_len = input.dimension(1);
    const int output_dim = input.dimension(3);

    // Sanity checks.
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, input.dimension(2), "Input num"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_state.dimension(0),
                                     "State batch"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, initial_state.dimension(1), "State dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_memory.dimension(0),
                                     "Memory batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_memory.dimension(1),
                                     "Memory dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(0), "Weight dim 0"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, w_m_m.dimension(1), "Weight dim 1"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(2), "Weight dim 2"));

    // Outputs.
    Tensor* act_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            0, {batch_size, seq_len, output_dim}, &act_tensor));
    auto act = act_tensor->tensor<float, 3>();
    act.setZero();

    Tensor* gate_raw_act_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, {batch_size, seq_len, 4, output_dim},
                                        &gate_raw_act_tensor));
    auto gate_raw_act = gate_raw_act_tensor->tensor<float, 4>();
    gate_raw_act.setZero();

    Tensor* memory_tensor = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, {batch_size, seq_len, output_dim},
                                        &memory_tensor));
    auto memory = memory_tensor->tensor<float, 3>();
    memory.setZero();

    // Const and scratch tensors.
    Tensor ones_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &ones_tensor));
    auto ones = ones_tensor.tensor<float, 2>();
    ones.setConstant(1.0);

    Tensor state_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &state_tensor));
    auto state = state_tensor.tensor<float, 2>();
    state = initial_state;

    Tensor scratch_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {batch_size, 4, output_dim},
                                      &scratch_tensor));
    auto scratch = scratch_tensor.tensor<float, 3>();
    scratch.setZero();

    // Uses the most efficient order for the contraction depending on the batch
    // size.
    //
    // This lambda is the code shared by both cases. Implicit capture in
    // lambdas is generally discouraged, but what is captured here should be
    // clear from the body.
    auto Forward = [&](int i) {
      // Each pre-activation value is stored in the following order (see the
      // comment of the op for the meaning):
      //
      //   i: 0
      //   j: 1
      //   f: 2
      //   o: 3

      // Adds one to the pre-activation values of the forget gate. This is a
      // heuristic to make the training easier.
      scratch.chip(2, 1) += ones;
      gate_raw_act.chip(i, 1) = scratch;

      // c_t = f_t * c_{t-1} + i_t * j_t
      if (i == 0) {
        state = initial_memory * scratch.chip(2, 1).sigmoid();
      } else {
        state = memory.chip(i - 1, 1) * scratch.chip(2, 1).sigmoid();
      }
      state += scratch.chip(0, 1).sigmoid() * scratch.chip(1, 1).tanh();
      if (clip_ > 0.0) {
        // Clips the values if required.
        state = state.cwiseMax(-clip_).cwiseMin(clip_);
      }
      memory.chip(i, 1) = state;

      // h_t = o_t * tanh(c_t)
      state = scratch.chip(3, 1).sigmoid() * state.tanh();
      act.chip(i, 1) = state;
    };
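
    // For reference, the shapes involved in the per-step contraction below,
    // writing B = batch_size, T = seq_len and N = output_dim:
    //
    //   input    [B, T, 4, N]  (input.chip(i, 1) is [B, 4, N])
    //   state    [B, N]
    //   w_m_m    [N, 4, N]
    //
    // In the B == 1 case w_m_m is viewed as an [N, 4 * N] matrix, so the
    // per-step contraction is a single [B, N] x [N, 4 * N] product reshaped
    // back to [B, 4, N]. Otherwise the shuffled weight [4 * N, N] is
    // contracted with state over the N dimension and shuffled back to
    // [B, 4, N].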
    if (batch_size == 1) {
      // Reshapes the weight tensor so that the contraction can be done as a
      // single matrix multiplication, which is more efficient.
      auto w_m_m_r =
          w_m_m.reshape(array<DenseIndex, 2>{output_dim, 4 * output_dim});
      // Dimensions for the contraction.
      const array<IndexPair, 1> m_m_dim = {IndexPair(1, 0)};
      for (int i = 0; i < seq_len; ++i) {
        // Computes the pre-activation value of the input and each gate.
        scratch = input.chip(i, 1) +
                  state.contract(w_m_m_r, m_m_dim)
                      .reshape(array<DenseIndex, 3>{batch_size, 4, output_dim});
        Forward(i);
      }
    } else {
      // Shuffles the dimensions of the weight tensor so that it is efficient
      // when used on the left-hand side of the contraction. Allocates memory
      // for the shuffled tensor for efficiency.
      Tensor w_m_m_s_tensor;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_FLOAT, {output_dim * 4, output_dim},
                                        &w_m_m_s_tensor));
      auto w_m_m_s = w_m_m_s_tensor.tensor<float, 2>();
      w_m_m_s = w_m_m.shuffle(array<int, 3>{2, 1, 0})
                    .reshape(array<DenseIndex, 2>{output_dim * 4, output_dim});
      // Dimensions for the contraction.
      const array<IndexPair, 1> m_m_dim = {IndexPair(1, 1)};
      for (int i = 0; i < seq_len; ++i) {
        // Computes the pre-activation value of the input and each gate.
        scratch = input.chip(i, 1) +
                  w_m_m_s.contract(state, m_m_dim)
                      .reshape(array<DenseIndex, 3>{output_dim, 4, batch_size})
                      .shuffle(array<int, 3>{2, 1, 0});
        Forward(i);
      }
    }
  }

 private:
  // Threshold to clip the values of memory cells.
  float clip_ = 0;
};

REGISTER_KERNEL_BUILDER(Name("VariableLSTM").Device(DEVICE_CPU),
                        VariableLSTMOp);

REGISTER_OP("VariableLSTM")
    .Attr("clip: float = 0.0")
    .Input("input: float32")
    .Input("initial_state: float32")
    .Input("initial_memory: float32")
    .Input("w_m_m: float32")
    .Output("activation: float32")
    .Output("gate_raw_act: float32")
    .Output("memory: float32")
    .Doc(R"doc(
Computes the forward propagation of a Long Short-Term Memory network.

It computes the following equation recursively for `0 < t <= T`:

  i_t = sigmoid(a_{i,t})
  j_t = tanh(a_{j,t})
  f_t = sigmoid(a_{f,t} + 1.0)
  o_t = sigmoid(a_{o,t})
  c_t = f_t * c_{t-1} + i_t * j_t
  c'_t = min(max(c_t, -clip), clip) if clip > 0 else c_t
  h_t = o_t * tanh(c'_t)

where

  a_{l,t} = w_{l,m,m} * h_{t-1} + x'_{l,t}

and

  x'_{l,t} = w_{l,m,i} * x_{t}.

`input` corresponds to the concatenation of `X'_i`, `X'_j`, `X'_f`, and `X'_o`
where `X'_l = (x'_{l,1}, x'_{l,2}, ..., x'_{l,T})`, `initial_state` corresponds
to `h_{0}`, `initial_memory` corresponds to `c_{0}`, and `w_m_m` corresponds to
`w_{l,m,m}`. `X'_l` (the transformed input) is computed outside of the op in
advance, so `w_{l,m,i}` is not passed in to the op.

`activation` corresponds to `H = (h_1, h_2, ..., h_T)`, `gate_raw_act`
corresponds to the concatenation of `A_i`, `A_j`, `A_f`, and `A_o`, and `memory`
corresponds to `C = (c_1, c_2, ..., c_T)`.

All entries in the batch are propagated to the end, and are assumed to be the
same length.

input: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
initial_state: 2-D with shape `[batch_size, num_nodes]`
initial_memory: 2-D with shape `[batch_size, num_nodes]`
w_m_m: 3-D with shape `[num_nodes, 4, num_nodes]`
activation: 3-D with shape `[batch_size, seq_len, num_nodes]`
gate_raw_act: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
memory: 3-D with shape `[batch_size, seq_len, num_nodes]`
)doc");
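
// As an illustrative example of the shapes above (sizes chosen arbitrarily):
// with batch_size = 2, seq_len = 3 and num_nodes = 5, `input` is [2, 3, 4, 5],
// `initial_state` and `initial_memory` are [2, 5], `w_m_m` is [5, 4, 5],
// `activation` and `memory` are [2, 3, 5], and `gate_raw_act` is [2, 3, 4, 5].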

// ----------------------------- VariableLSTMGradOp ---------------------------

// Kernel to compute the gradient of VariableLSTMOp.
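//
// The first three inputs (initial_state, initial_memory, w_m_m) are expected
// to be the same tensors that were fed to VariableLSTM, the next three
// (activation, gate_raw_act, memory) are its outputs, and the last three
// (act_grad, gate_raw_act_grad, memory_grad) are the incoming gradients with
// respect to those outputs.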
class VariableLSTMGradOp : public OpKernel {
 public:
  explicit VariableLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Inputs.
    const auto initial_state = ctx->input(0).tensor<float, 2>();
    const auto initial_memory = ctx->input(1).tensor<float, 2>();
    const auto w_m_m = ctx->input(2).tensor<float, 3>();
    const auto act = ctx->input(3).tensor<float, 3>();
    const auto gate_raw_act = ctx->input(4).tensor<float, 4>();
    const auto memory = ctx->input(5).tensor<float, 3>();
    const auto act_grad = ctx->input(6).tensor<float, 3>();
    const auto gate_raw_act_grad = ctx->input(7).tensor<float, 4>();
    const auto memory_grad = ctx->input(8).tensor<float, 3>();

    const int batch_size = act.dimension(0);
    const int seq_len = act.dimension(1);
    const int output_dim = act.dimension(2);

    // Sanity checks.
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_state.dimension(0),
                                     "State batch"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, initial_state.dimension(1), "State dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, initial_memory.dimension(0),
                                     "Memory batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, initial_memory.dimension(1),
                                     "Memory dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(0), "Weight dim 0"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, w_m_m.dimension(1), "Weight dim 1"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(output_dim, w_m_m.dimension(2), "Weight dim 2"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, gate_raw_act.dimension(0),
                                     "Gate raw activation batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, gate_raw_act.dimension(1),
                                     "Gate raw activation len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, gate_raw_act.dimension(2),
                                     "Gate raw activation num"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, gate_raw_act.dimension(3),
                                     "Gate raw activation dim"));
    OP_REQUIRES_OK(
        ctx, AreDimsEqual(batch_size, memory.dimension(0), "Memory batch"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(seq_len, memory.dimension(1), "Memory len"));
    OP_REQUIRES_OK(ctx,
                   AreDimsEqual(output_dim, memory.dimension(2), "Memory dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, act_grad.dimension(0),
                                     "Activation gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, act_grad.dimension(1),
                                     "Activation gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, act_grad.dimension(2),
                                     "Activation gradient dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, gate_raw_act_grad.dimension(0),
                                     "Gate raw activation gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, gate_raw_act_grad.dimension(1),
                                     "Gate raw activation gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(4, gate_raw_act_grad.dimension(2),
                                     "Gate raw activation gradient num"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, gate_raw_act_grad.dimension(3),
                                     "Gate raw activation gradient dim"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(batch_size, memory_grad.dimension(0),
                                     "Memory gradient batch"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(seq_len, memory_grad.dimension(1),
                                     "Memory gradient len"));
    OP_REQUIRES_OK(ctx, AreDimsEqual(output_dim, memory_grad.dimension(2),
                                     "Memory gradient dim"));

    // Outputs.
    std::vector<Tensor*> collections(4, nullptr);
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, {batch_size, seq_len, 4, output_dim},
                                        &collections[0]));
    auto input_grad = collections[0]->tensor<float, 4>();
    input_grad.setZero();

    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {batch_size, output_dim},
                                             &collections[1]));
    auto init_state_grad = collections[1]->tensor<float, 2>();
    init_state_grad.setZero();

    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, {batch_size, output_dim},
                                             &collections[2]));
    auto init_memory_grad = collections[2]->tensor<float, 2>();
    init_memory_grad.setZero();

    OP_REQUIRES_OK(ctx, ctx->allocate_output(3, {output_dim, 4, output_dim},
                                             &collections[3]));
    auto w_m_m_grad = collections[3]->tensor<float, 3>();
    w_m_m_grad.setZero();

    // Const and scratch tensors.
    Tensor ones_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &ones_tensor));
    auto ones = ones_tensor.tensor<float, 2>();
    ones.setConstant(1.0);

    Tensor scratch_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {batch_size, 4, output_dim},
                                      &scratch_tensor));
    auto scratch = scratch_tensor.tensor<float, 3>();
    scratch.setZero();

    Tensor tmp1_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &tmp1_tensor));
    auto tmp1 = tmp1_tensor.tensor<float, 2>();
    tmp1.setZero();

    Tensor tmp2_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {batch_size, output_dim},
                                           &tmp2_tensor));
    auto tmp2 = tmp2_tensor.tensor<float, 2>();
    tmp2.setZero();

    // Uses the most efficient order for the contraction depending on the batch
    // size.
    //
    // Shuffles the dimensions of the weight tensor so that it is efficient
    // when used on the left-hand side of the contraction. Allocates memory
    // for the shuffled tensor for efficiency.
    Tensor w_m_m_s_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_FLOAT, {4, output_dim, output_dim},
                                      &w_m_m_s_tensor));
    auto w_m_m_s = w_m_m_s_tensor.tensor<float, 3>();
    if (batch_size == 1) {
      // Fills the shuffled tensor only when it is actually used (the
      // single-element batch case).
      w_m_m_s = w_m_m.shuffle(array<int, 3>{1, 2, 0});
    }

    // Dimensions for the contraction with the weight tensor.
    const array<IndexPair, 1> m_m_dim =
        batch_size == 1 ? array<IndexPair, 1>{IndexPair(1, 0)}
                        : array<IndexPair, 1>{IndexPair(1, 1)};
    // Dimensions for the contraction of the batch dimensions.
    const array<IndexPair, 1> b_b_dim = {IndexPair(0, 0)};
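
    // Backpropagation through time: the loop below walks from the last step
    // to the first. Each iteration first forms the gradient with respect to
    // h_t in init_state_grad (act_grad for that step plus, for non-final
    // steps, the gate gradients of step t + 1 propagated back through w_m_m),
    // then accumulates the gradient with respect to c_t in init_memory_grad
    // and fills scratch with the gradients of the four gate pre-activations,
    // which are written to input_grad. w_m_m_grad is accumulated from the
    // contraction of h_t with the gate gradients of step t + 1 over the batch
    // dimension.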
    for (int i = seq_len - 1; i >= 0; --i) {
      if (i == seq_len - 1) {
        init_state_grad = act_grad.chip(i, 1);
      } else {
        w_m_m_grad +=
            act.chip(i, 1)
                .contract(scratch.reshape(
                              array<DenseIndex, 2>{batch_size, 4 * output_dim}),
                          b_b_dim)
                .reshape(array<DenseIndex, 3>{output_dim, 4, output_dim});
        if (batch_size == 1) {
          init_state_grad.device(ctx->eigen_cpu_device()) =
              scratch.chip(0, 1).contract(w_m_m_s.chip(0, 0), m_m_dim) +
              scratch.chip(1, 1).contract(w_m_m_s.chip(1, 0), m_m_dim) +
              scratch.chip(2, 1).contract(w_m_m_s.chip(2, 0), m_m_dim) +
              scratch.chip(3, 1).contract(w_m_m_s.chip(3, 0), m_m_dim);
        } else {
          init_state_grad.device(ctx->eigen_cpu_device()) =
              (w_m_m.chip(0, 1).contract(scratch.chip(0, 1), m_m_dim) +
               w_m_m.chip(1, 1).contract(scratch.chip(1, 1), m_m_dim) +
               w_m_m.chip(2, 1).contract(scratch.chip(2, 1), m_m_dim) +
               w_m_m.chip(3, 1).contract(scratch.chip(3, 1), m_m_dim))
                  .shuffle(array<int, 2>{1, 0});
        }
        init_state_grad += act_grad.chip(i, 1);
      }

      auto gate_raw_act_t = gate_raw_act.chip(i, 1);
      auto gate_raw_act_grad_t = gate_raw_act_grad.chip(i, 1);

      // Output gate.
      tmp1 = memory.chip(i, 1);
      tmp1 = tmp1.tanh();                          // tanh(c_t)
      tmp2 = gate_raw_act_t.chip(3, 1).sigmoid();  // o_t
      scratch.chip(3, 1) = init_state_grad * tmp1 * tmp2 * (ones - tmp2) +
                           gate_raw_act_grad_t.chip(3, 1);
      init_memory_grad += init_state_grad * tmp2 * (ones - tmp1.square()) +
                          memory_grad.chip(i, 1);

      // Input gate.
      tmp1 = gate_raw_act_t.chip(0, 1).sigmoid();  // i_t
      tmp2 = gate_raw_act_t.chip(1, 1);
      tmp2 = tmp2.tanh();                          // j_t
      scratch.chip(0, 1) = init_memory_grad * tmp2 * tmp1 * (ones - tmp1) +
                           gate_raw_act_grad_t.chip(0, 1);

      // Input.
      scratch.chip(1, 1) = init_memory_grad * tmp1 * (ones - tmp2.square()) +
                           gate_raw_act_grad_t.chip(1, 1);

      // Forget gate.
      tmp1 = gate_raw_act_t.chip(2, 1).sigmoid();  // f_t
      if (i == 0) {
        scratch.chip(2, 1) =
            init_memory_grad * initial_memory * tmp1 * (ones - tmp1) +
            gate_raw_act_grad_t.chip(2, 1);
      } else {
        scratch.chip(2, 1) =
            init_memory_grad * memory.chip(i - 1, 1) * tmp1 * (ones - tmp1) +
            gate_raw_act_grad_t.chip(2, 1);
      }

      // Memory.
      init_memory_grad *= tmp1;

      input_grad.chip(i, 1) = scratch;
    }
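
    // After the loop, scratch holds the gate gradients of the first step,
    // whose pre-activations were computed from initial_state. The block below
    // adds the corresponding contribution to w_m_m_grad and propagates those
    // gradients through w_m_m to obtain init_state_grad, the gradient with
    // respect to initial_state.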
    w_m_m_grad += initial_state
                      .contract(scratch.reshape(array<DenseIndex, 2>{
                                    batch_size, 4 * output_dim}),
                                b_b_dim)
                      .reshape(array<DenseIndex, 3>{output_dim, 4, output_dim});
    if (batch_size == 1) {
      init_state_grad.device(ctx->eigen_cpu_device()) =
          (scratch.chip(0, 1).contract(w_m_m_s.chip(0, 0), m_m_dim) +
           scratch.chip(1, 1).contract(w_m_m_s.chip(1, 0), m_m_dim) +
           scratch.chip(2, 1).contract(w_m_m_s.chip(2, 0), m_m_dim) +
           scratch.chip(3, 1).contract(w_m_m_s.chip(3, 0), m_m_dim));
    } else {
      init_state_grad.device(ctx->eigen_cpu_device()) =
          (w_m_m.chip(0, 1).contract(scratch.chip(0, 1), m_m_dim) +
           w_m_m.chip(1, 1).contract(scratch.chip(1, 1), m_m_dim) +
           w_m_m.chip(2, 1).contract(scratch.chip(2, 1), m_m_dim) +
           w_m_m.chip(3, 1).contract(scratch.chip(3, 1), m_m_dim))
              .shuffle(array<int, 2>{1, 0});
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("VariableLSTMGrad").Device(DEVICE_CPU),
                        VariableLSTMGradOp);

REGISTER_OP("VariableLSTMGrad")
    .Input("initial_state: float32")
    .Input("initial_memory: float32")
    .Input("w_m_m: float32")
    .Input("activation: float32")
    .Input("gate_raw_act: float32")
    .Input("memory: float32")
    .Input("act_grad: float32")
    .Input("gate_raw_act_grad: float32")
    .Input("memory_grad: float32")
    .Output("input_grad: float32")
    .Output("initial_state_grad: float32")
    .Output("initial_memory_grad: float32")
    .Output("w_m_m_grad: float32")
    .Doc(R"doc(
Computes the gradient for VariableLSTM.

This is to be used in conjunction with VariableLSTM. It ignores the clipping
used in the forward pass.

initial_state: 2-D with shape `[batch_size, num_nodes]`
initial_memory: 2-D with shape `[batch_size, num_nodes]`
w_m_m: 3-D with shape `[num_nodes, 4, num_nodes]`
activation: 3-D with shape `[batch_size, seq_len, num_nodes]`
gate_raw_act: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
memory: 3-D with shape `[batch_size, seq_len, num_nodes]`
act_grad: 3-D with shape `[batch_size, seq_len, num_nodes]`
gate_raw_act_grad: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
memory_grad: 3-D with shape `[batch_size, seq_len, num_nodes]`
input_grad: 4-D with shape `[batch_size, seq_len, 4, num_nodes]`
initial_state_grad: 2-D with shape `[batch_size, num_nodes]`
initial_memory_grad: 2-D with shape `[batch_size, num_nodes]`
w_m_m_grad: 3-D with shape `[num_nodes, 4, num_nodes]`
)doc");

}  // namespace tensorflow