Ver código fonte

Rename sampled_loss argument inputs to logits in preparation for named arguments requirement

Neal Wu 8 anos atrás
pai
commit
f7cea8d01b
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      tutorials/rnn/translate/seq2seq_model.py

+ 2 - 2
tutorials/rnn/translate/seq2seq_model.py

@@ -100,13 +100,13 @@ class Seq2SeqModel(object):
       b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
       output_projection = (w, b)
 
-      def sampled_loss(labels, inputs):
+      def sampled_loss(labels, logits):
         labels = tf.reshape(labels, [-1, 1])
         # We need to compute the sampled_softmax_loss using 32bit floats to
         # avoid numerical instabilities.
         local_w_t = tf.cast(w_t, tf.float32)
         local_b = tf.cast(b, tf.float32)
-        local_inputs = tf.cast(inputs, tf.float32)
+        local_inputs = tf.cast(logits, tf.float32)
         return tf.cast(
             tf.nn.sampled_softmax_loss(
                 weights=local_w_t,