Explorar o código

Merge pull request #601 from panyx0718/master

Explicitly set state_is_tuple=False.
Xin Pan %!s(int64=9) %!d(string=hai) anos
pai
achega
9b023de8cb
Modificáronse 1 ficheiros con 6 adicións e 3 borrados
  1. 6 3
      textsum/seq2seq_attention_model.py

+ 6 - 3
textsum/seq2seq_attention_model.py

@@ -160,10 +160,12 @@ class Seq2SeqAttentionModel(object):
             self._next_device()):
             self._next_device()):
           cell_fw = tf.nn.rnn_cell.LSTMCell(
           cell_fw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
+              state_is_tuple=False)
           cell_bw = tf.nn.rnn_cell.LSTMCell(
           cell_bw = tf.nn.rnn_cell.LSTMCell(
               hps.num_hidden,
               hps.num_hidden,
-              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+              state_is_tuple=False)
           (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
           (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
               cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
               cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
               sequence_length=article_lens)
               sequence_length=article_lens)
@@ -188,7 +190,8 @@ class Seq2SeqAttentionModel(object):
 
 
         cell = tf.nn.rnn_cell.LSTMCell(
         cell = tf.nn.rnn_cell.LSTMCell(
             hps.num_hidden,
             hps.num_hidden,
-            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
+            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
+            state_is_tuple=False)
 
 
         encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
         encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
                            for x in encoder_outputs]
                            for x in encoder_outputs]