Browse Source

Explicitly set state_is_tuple=False.

Xin Pan 9 năm trước
mục cha
commit
5e875226bc
1 tập tin đã thay đổi với 6 bổ sung và 3 xóa
  1. 6 3
      textsum/seq2seq_attention_model.py

+ 6 - 3
textsum/seq2seq_attention_model.py

@@ -160,10 +160,12 @@ class Seq2SeqAttentionModel(object):
             self._next_device()):
           cell_fw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
+              state_is_tuple=False)
           cell_bw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+              state_is_tuple=False)
           (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
               cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
               sequence_length=article_lens)
@@ -188,7 +190,8 @@ class Seq2SeqAttentionModel(object):
 
         cell = tf.nn.rnn_cell.LSTMCell(
             hps.num_hidden,
-            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+            state_is_tuple=False)
 
         encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
                            for x in encoder_outputs]