|
@@ -100,13 +100,13 @@ class Seq2SeqModel(object):
|
|
|
b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
|
|
|
output_projection = (w, b)
|
|
|
|
|
|
- def sampled_loss(labels, inputs):
|
|
|
+ def sampled_loss(labels, logits):
|
|
|
labels = tf.reshape(labels, [-1, 1])
|
|
|
# We need to compute the sampled_softmax_loss using 32bit floats to
|
|
|
# avoid numerical instabilities.
|
|
|
local_w_t = tf.cast(w_t, tf.float32)
|
|
|
local_b = tf.cast(b, tf.float32)
|
|
|
- local_inputs = tf.cast(inputs, tf.float32)
|
|
|
+ local_inputs = tf.cast(logits, tf.float32)
|
|
|
return tf.cast(
|
|
|
tf.nn.sampled_softmax_loss(
|
|
|
weights=local_w_t,
|