浏览代码

Rewrite to use inspect.getargspec

Neal Wu 8 年之前
父节点
当前提交
c15fada281
共有 1 个文件被更改,包括 6 次插入4 次删除
  1. 6 4
      tutorials/rnn/ptb/ptb_word_lm.py

+ 6 - 4
tutorials/rnn/ptb/ptb_word_lm.py

@@ -56,6 +56,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import inspect
 import time
 
 import numpy as np
@@ -111,13 +112,14 @@ class PTBModel(object):
     def lstm_cell():
       # With the latest TensorFlow source code (as of Mar 27, 2017),
       # the BasicLSTMCell will need a reuse parameter which is unfortunately not
-      # defined in TensorFlow 1.0. To maintain backwards compatibility, we add a
-      # try-except here:
-      try:
+      # defined in TensorFlow 1.0. To maintain backwards compatibility, we add
+      # an argument check here:
+      if 'reuse' in inspect.getargspec(
+          tf.contrib.rnn.BasicLSTMCell.__init__).args:
         return tf.contrib.rnn.BasicLSTMCell(
             size, forget_bias=0.0, state_is_tuple=True,
             reuse=tf.get_variable_scope().reuse)
-      except TypeError:
+      else:
         return tf.contrib.rnn.BasicLSTMCell(
             size, forget_bias=0.0, state_is_tuple=True)
     attn_cell = lstm_cell