Pārlūkot izejas kodu

Make ptb_word_lm compatible with the latest TensorFlow source while maintaining backwards compatibility with TF 1.0

Neal Wu 8 gadi atpakaļ
vecāks
revīzija
167b6c6922
1 mainītis faili ar 11 papildinājumiem un 2 dzēšanām
  1. 11 2
      tutorials/rnn/ptb/ptb_word_lm.py

+ 11 - 2
tutorials/rnn/ptb/ptb_word_lm.py

@@ -109,8 +109,17 @@ class PTBModel(object):
     # initialized to 1 but the hyperparameters of the model would need to be
     # different than reported in the paper.
     def lstm_cell():
-      return tf.contrib.rnn.BasicLSTMCell(
-          size, forget_bias=0.0, state_is_tuple=True)
+      # With the latest TensorFlow source code (as of Mar 27, 2017),
+      # the BasicLSTMCell will need a reuse parameter which is unfortunately not
+      # defined in TensorFlow 1.0. To maintain backwards compatibility, we add a
+      # try-except here:
+      try:
+        return tf.contrib.rnn.BasicLSTMCell(
+            size, forget_bias=0.0, state_is_tuple=True,
+            reuse=tf.get_variable_scope().reuse)
+      except TypeError:
+        return tf.contrib.rnn.BasicLSTMCell(
+            size, forget_bias=0.0, state_is_tuple=True)
     attn_cell = lstm_cell
     if is_training and config.keep_prob < 1:
       def attn_cell():