gru_cell.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """GRU cell implementation for the skip-thought vectors model."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. _layer_norm = tf.contrib.layers.layer_norm
  21. class LayerNormGRUCell(tf.contrib.rnn.RNNCell):
  22. """GRU cell with layer normalization.
  23. The layer normalization implementation is based on:
  24. https://arxiv.org/abs/1607.06450.
  25. "Layer Normalization"
  26. Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
  27. """

  def __init__(self,
               num_units,
               w_initializer,
               u_initializer,
               b_initializer,
               activation=tf.nn.tanh):
    """Initializes the cell.

    Args:
      num_units: Number of cell units.
      w_initializer: Initializer for the "W" (input) parameter matrices.
      u_initializer: Initializer for the "U" (recurrent) parameter matrices.
      b_initializer: Initializer for the "b" (bias) parameter vectors.
      activation: Cell activation function.
    """
    self._num_units = num_units
    self._w_initializer = w_initializer
    self._u_initializer = u_initializer
    self._b_initializer = b_initializer
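    # Note: _b_initializer is stored but not referenced in __call__ below;
    # the layer_norm ops contribute their own learned offset terms.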
    self._activation = activation

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def _w_h_initializer(self):
    """Returns an initializer for the "W_h" parameter matrix.

    See equation (23) in the paper. The "W_h" parameter matrix is the
    concatenation of two parameter submatrices. The matrix returned is
    [U_z, U_r].

    Returns:
      A Tensor with shape [num_units, 2 * num_units] as described above.
    """

    def _initializer(shape, dtype=tf.float32, partition_info=None):
      num_units = self._num_units
      assert shape == [num_units, 2 * num_units]
      u_z = self._u_initializer([num_units, num_units], dtype, partition_info)
      u_r = self._u_initializer([num_units, num_units], dtype, partition_info)
      return tf.concat([u_z, u_r], 1)

    return _initializer

  def _w_x_initializer(self, input_dim):
    """Returns an initializer for the "W_x" parameter matrix.

    See equation (23) in the paper. The "W_x" parameter matrix is the
    concatenation of two parameter submatrices. The matrix returned is
    [W_z, W_r].

    Args:
      input_dim: The dimension of the cell inputs.

    Returns:
      A Tensor with shape [input_dim, 2 * num_units] as described above.
    """

    def _initializer(shape, dtype=tf.float32, partition_info=None):
      num_units = self._num_units
      assert shape == [input_dim, 2 * num_units]
      w_z = self._w_initializer([input_dim, num_units], dtype, partition_info)
      w_r = self._w_initializer([input_dim, num_units], dtype, partition_info)
      return tf.concat([w_z, w_r], 1)

    return _initializer

  def __call__(self, inputs, state, scope=None):
    """GRU cell with layer normalization."""
    input_dim = inputs.get_shape().as_list()[1]
    num_units = self._num_units

    with tf.variable_scope(scope or "gru_cell"):
      with tf.variable_scope("gates"):
        w_h = tf.get_variable(
            "w_h", [num_units, 2 * num_units],
            initializer=self._w_h_initializer())
        w_x = tf.get_variable(
            "w_x", [input_dim, 2 * num_units],
            initializer=self._w_x_initializer(input_dim))
        z_and_r = (_layer_norm(tf.matmul(state, w_h), scope="layer_norm/w_h") +
                   _layer_norm(tf.matmul(inputs, w_x), scope="layer_norm/w_x"))
        z, r = tf.split(tf.sigmoid(z_and_r), 2, 1)
      with tf.variable_scope("candidate"):
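        # Candidate pre-activation: the reset gate scales the layer-normalized
        # recurrent term before it is combined with the layer-normalized input
        # term.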
        w = tf.get_variable(
            "w", [input_dim, num_units], initializer=self._w_initializer)
        u = tf.get_variable(
            "u", [num_units, num_units], initializer=self._u_initializer)
        h_hat = (r * _layer_norm(tf.matmul(state, u), scope="layer_norm/u") +
                 _layer_norm(tf.matmul(inputs, w), scope="layer_norm/w"))
      new_h = (1 - z) * state + z * self._activation(h_hat)
    return new_h, new_h
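

# ------------------------------------------------------------------------------
# Minimal usage sketch. Assumes TF 1.x with tf.contrib available; the
# initializer choices below are illustrative placeholders, not the ones the
# skip-thoughts training code supplies.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
  batch_size, seq_len, input_dim, num_units = 4, 10, 32, 64
  inputs = tf.random_normal([batch_size, seq_len, input_dim])

  cell = LayerNormGRUCell(
      num_units=num_units,
      w_initializer=tf.random_uniform_initializer(-0.1, 0.1),
      u_initializer=tf.orthogonal_initializer(),
      b_initializer=tf.constant_initializer(0.0),
      activation=tf.nn.tanh)

  # Unroll the cell over the time dimension.
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_val, state_val = sess.run([outputs, final_state])
    print("outputs:", out_val.shape)        # (4, 10, 64)
    print("final state:", state_val.shape)  # (4, 64)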