|
@@ -56,6 +56,7 @@ from __future__ import absolute_import
|
|
|
from __future__ import division
|
|
|
from __future__ import print_function
|
|
|
|
|
|
+import inspect
|
|
|
import time
|
|
|
|
|
|
import numpy as np
|
|
@@ -111,13 +112,14 @@ class PTBModel(object):
|
|
|
def lstm_cell():
|
|
|
# With the latest TensorFlow source code (as of Mar 27, 2017),
|
|
|
# the BasicLSTMCell will need a reuse parameter which is unfortunately not
|
|
|
- # defined in TensorFlow 1.0. To maintain backwards compatibility, we add a
|
|
|
- # try-except here:
|
|
|
- try:
|
|
|
+ # defined in TensorFlow 1.0. To maintain backwards compatibility, we add
|
|
|
+ # an argument check here:
|
|
|
+ if 'reuse' in inspect.getargspec(
|
|
|
+ tf.contrib.rnn.BasicLSTMCell.__init__).args:
|
|
|
return tf.contrib.rnn.BasicLSTMCell(
|
|
|
size, forget_bias=0.0, state_is_tuple=True,
|
|
|
reuse=tf.get_variable_scope().reuse)
|
|
|
- except TypeError:
|
|
|
+ else:
|
|
|
return tf.contrib.rnn.BasicLSTMCell(
|
|
|
size, forget_bias=0.0, state_is_tuple=True)
|
|
|
attn_cell = lstm_cell
|