
Explicitly set state_is_tuple=False.
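For context, state_is_tuple controls how LSTMCell packs its recurrent state. A minimal sketch of the difference, assuming the TensorFlow 0.x-era tf.nn.rnn_cell API this model targets (num_hidden here is a stand-in for hps.num_hidden):

    import tensorflow as tf

    num_hidden = 256  # stand-in for hps.num_hidden

    # With state_is_tuple=False the cell returns its state as one tensor of
    # shape [batch, 2 * num_hidden], with c and h concatenated along axis 1.
    # The surrounding code in seq2seq_attention_model.py consumes the state
    # in this packed form.
    cell = tf.nn.rnn_cell.LSTMCell(
        num_hidden,
        initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
        state_is_tuple=False)

    # With state_is_tuple=True (the direction the TensorFlow default was
    # moving), the state is instead an LSTMStateTuple(c, h) pair, which
    # breaks code written against the concatenated tensor; hence the
    # explicit False in this commit.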

Xin Pan committed 9 years ago
parent commit 5e875226bc
1 file changed, 6 insertions and 3 deletions

textsum/seq2seq_attention_model.py  +6 -3

@@ -160,10 +160,12 @@ class Seq2SeqAttentionModel(object):
             self._next_device()):
           cell_fw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
+              state_is_tuple=False)
           cell_bw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+              state_is_tuple=False)
           (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
               cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
               sequence_length=article_lens)
@@ -188,7 +190,8 @@ class Seq2SeqAttentionModel(object):
 
         cell = tf.nn.rnn_cell.LSTMCell(
             hps.num_hidden,
-            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+            state_is_tuple=False)
 
         encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
                            for x in encoder_outputs]