@@ -227,8 +227,9 @@ class Seq2SeqAttentionModel(object):
     def sampled_loss_func(inputs, labels):
       with tf.device('/cpu:0'):  # Try gpu.
         labels = tf.reshape(labels, [-1, 1])
-        return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels,
-                                          hps.num_softmax_samples, vsize)
+        return tf.nn.sampled_softmax_loss(
+            weights=w_t, biases=v, labels=labels, inputs=inputs,
+            num_sampled=hps.num_softmax_samples, num_classes=vsize)

     if hps.num_softmax_samples != 0 and hps.mode == 'train':
       self._loss = seq2seq_lib.sampled_sequence_loss(