Explorar o código

Merge pull request #982 from arvind385801/hotfix/tranlate_tutorial

fixed a bug in sampled_loss(), made compatible for 0.12.0
Neal Wu %!s(int64=8) %!d(string=hai) anos
pai
achega
0079e98a24
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      tutorials/rnn/translate/seq2seq_model.py

+ 1 - 1
tutorials/rnn/translate/seq2seq_model.py

@@ -100,7 +100,7 @@ class Seq2SeqModel(object):
       b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
       output_projection = (w, b)
 
-      def sampled_loss(labels, inputs):
+      def sampled_loss(inputs, labels):
         labels = tf.reshape(labels, [-1, 1])
         # We need to compute the sampled_softmax_loss using 32bit floats to
         # avoid numerical instabilities.