@@ -158,11 +158,11 @@ class Seq2SeqAttentionModel(object):
       for layer_i in xrange(hps.enc_layers):
         with tf.variable_scope('encoder%d'%layer_i), tf.device(
             self._next_device()):
-          cell_fw = tf.nn.rnn_cell.LSTMCell(
+          cell_fw = tf.contrib.rnn.LSTMCell(
               hps.num_hidden,
               initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
               state_is_tuple=False)
-          cell_bw = tf.nn.rnn_cell.LSTMCell(
+          cell_bw = tf.contrib.rnn.LSTMCell(
               hps.num_hidden,
               initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
               state_is_tuple=False)
@@ -188,7 +188,7 @@ class Seq2SeqAttentionModel(object):
           loop_function = _extract_argmax_and_embed(
               embedding, (w, v), update_embedding=False)

-        cell = tf.nn.rnn_cell.LSTMCell(
+        cell = tf.contrib.rnn.LSTMCell(
            hps.num_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
            state_is_tuple=False)
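
Both hunks swap the deprecated tf.nn.rnn_cell.LSTMCell path for tf.contrib.rnn.LSTMCell; the constructor arguments themselves are unchanged. The snippet below is a minimal usage sketch, not part of the patch, showing the renamed constructor driving one bidirectional encoder layer the way the file above does. It assumes a TensorFlow 1.x build that still ships tf.contrib; the hidden size, seeds, and state_is_tuple=False mirror the diff, while the placeholder inputs and the static_bidirectional_rnn call are illustrative stand-ins for the model's real encoder wiring.

# Sketch only: assumes TensorFlow 1.x with tf.contrib available.
import tensorflow as tf

num_hidden = 256   # stands in for hps.num_hidden
emb_dim = 128      # illustrative embedding size
seq_len = 10       # illustrative unrolled sequence length

cell_fw = tf.contrib.rnn.LSTMCell(
    num_hidden,
    initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
    state_is_tuple=False)
cell_bw = tf.contrib.rnn.LSTMCell(
    num_hidden,
    initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
    state_is_tuple=False)

# static_bidirectional_rnn takes a Python list of [batch, emb_dim] tensors,
# one per time step, and returns outputs with fw/bw results concatenated.
inputs = [tf.placeholder(tf.float32, [None, emb_dim]) for _ in range(seq_len)]
outputs, state_fw, state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)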