seq2seq_lib.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""seq2seq library code copied from elsewhere for customization."""

import tensorflow as tf


# Adapted to support a sampled_softmax loss function, which accepts activations
# instead of logits.
def sequence_loss_by_example(inputs, targets, weights, loss_function,
                             average_across_timesteps=True, name=None):
  """Sampled softmax loss for a sequence of inputs (per example).

  Args:
    inputs: List of 2D Tensors of shape [batch_size x hid_dim].
    targets: List of 1D batch-sized int32 Tensors of the same length as inputs.
    weights: List of 1D batch-sized float Tensors of the same length as inputs.
    loss_function: Sampled softmax function: (inputs, labels) -> loss.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    name: Optional name for this operation, default: 'sequence_loss_by_example'.

  Returns:
    1D batch-sized float Tensor: The log-perplexity for each sequence.

  Raises:
    ValueError: If len(inputs) is different from len(targets) or len(weights).
  """
  if len(targets) != len(inputs) or len(weights) != len(inputs):
    raise ValueError('Lengths of inputs, weights, and targets must be the same '
                     '%d, %d, %d.' % (len(inputs), len(weights), len(targets)))
  with tf.name_scope(values=inputs + targets + weights, name=name,
                     default_name='sequence_loss_by_example'):
    log_perp_list = []
    for inp, target, weight in zip(inputs, targets, weights):
      # Per-timestep loss from the caller-supplied (sampled softmax) loss
      # function, weighted by that step's target weight.
      crossent = loss_function(inp, target)
      log_perp_list.append(crossent * weight)
    log_perps = tf.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = tf.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps
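

# A minimal sketch (not part of the original file) of one way to build the
# `loss_function` argument above with tf.nn.sampled_softmax_loss.  The output
# projection `w_t`/`b`, `num_samples`, and `vocab_size` are hypothetical names
# assumed to be defined by the caller.
#
#   def sampled_loss(inputs, labels):
#     # sampled_softmax_loss expects labels of shape [batch_size, num_true].
#     labels = tf.reshape(labels, [-1, 1])
#     return tf.nn.sampled_softmax_loss(
#         weights=w_t, biases=b, labels=labels, inputs=inputs,
#         num_sampled=num_samples, num_classes=vocab_size)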


def sampled_sequence_loss(inputs, targets, weights, loss_function,
                          average_across_timesteps=True,
                          average_across_batch=True, name=None):
  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.

  Args:
    inputs: List of 2D Tensors of shape [batch_size x hid_dim].
    targets: List of 1D batch-sized int32 Tensors of the same length as inputs.
    weights: List of 1D batch-sized float Tensors of the same length as inputs.
    loss_function: Sampled softmax function: (inputs, labels) -> loss.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    average_across_batch: If set, divide the returned cost by the batch size.
    name: Optional name for this operation, defaults to 'sampled_sequence_loss'.

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: If len(inputs) is different from len(targets) or len(weights).
  """
  with tf.name_scope(values=inputs + targets + weights, name=name,
                     default_name='sampled_sequence_loss'):
    # Sum the per-example log-perplexities over the batch.
    cost = tf.reduce_sum(sequence_loss_by_example(
        inputs, targets, weights, loss_function,
        average_across_timesteps=average_across_timesteps))
    if average_across_batch:
      batch_size = tf.shape(targets[0])[0]
      return cost / tf.cast(batch_size, tf.float32)
    else:
      return cost
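

# Illustrative usage (not in the original file): with per-timestep decoder
# output activations and the `sampled_loss` sketched above, the training
# objective is a single scalar.  `decoder_outputs`, `target_ids`, and
# `target_weights` are hypothetical caller-side lists of per-timestep tensors.
#
#   loss = sampled_sequence_loss(decoder_outputs, target_ids, target_weights,
#                                sampled_loss)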


def linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where the W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments have unspecified or wrong shapes.
  """
  if args is None or (isinstance(args, (list, tuple)) and not args):
    raise ValueError('`args` must be specified')
  if not isinstance(args, (list, tuple)):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  for shape in shapes:
    if len(shape) != 2:
      raise ValueError('Linear is expecting 2D arguments: %s' % str(shapes))
    if not shape[1]:
      raise ValueError('Linear expects shape[1] of arguments: %s' % str(shapes))
    else:
      total_arg_size += shape[1]

  # Now the computation.
  with tf.variable_scope(scope or 'Linear'):
    matrix = tf.get_variable('Matrix', [total_arg_size, output_size])
    if len(args) == 1:
      res = tf.matmul(args[0], matrix)
    else:
      res = tf.matmul(tf.concat(axis=1, values=args), matrix)
    if not bias:
      return res
    bias_term = tf.get_variable(
        'Bias', [output_size],
        initializer=tf.constant_initializer(bias_start))
  return res + bias_term
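

# A minimal smoke test added for illustration (not part of the original file).
# It assumes TensorFlow 1.x and uses a plain softmax cross-entropy as a
# stand-in for the sampled-softmax loss_function; shapes and names here are
# arbitrary demo choices.
if __name__ == '__main__':
  batch_size, hid_dim, num_steps = 4, 8, 3
  acts = [tf.random_normal([batch_size, hid_dim]) for _ in range(num_steps)]
  ids = [tf.constant([1, 2, 3, 4], dtype=tf.int32) for _ in range(num_steps)]
  wts = [tf.ones([batch_size]) for _ in range(num_steps)]

  def toy_loss(inp, labels):
    # Treat the activations directly as logits for this toy check.
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=inp)

  loss = sampled_sequence_loss(acts, ids, wts, toy_loss)
  proj = linear(acts[0], output_size=5, bias=True, scope='demo_projection')

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_val, proj_val = sess.run([loss, proj])
    print('toy sequence loss: %f' % loss_val)
    print('linear projection shape: %s' % (proj_val.shape,))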