
Force new instance creation in MultiRNNCell (See also CL 145094809)

Neal Wu authored 8 years ago
commit 520b557e09
2 files changed, 14 insertions(+), 9 deletions(-)
  1. tutorials/rnn/ptb/ptb_word_lm.py (+8, -5)
  2. tutorials/rnn/translate/seq2seq_model.py (+6, -4)

tutorials/rnn/ptb/ptb_word_lm.py (+8, -5)

@@ -108,13 +108,16 @@ class PTBModel(object):
     # Slightly better results can be obtained with forget gate biases
     # initialized to 1 but the hyperparameters of the model would need to be
     # different than reported in the paper.
-    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
-        size, forget_bias=0.0, state_is_tuple=True)
+    def lstm_cell():
+      return tf.contrib.rnn.BasicLSTMCell(
+          size, forget_bias=0.0, state_is_tuple=True)
+    attn_cell = lstm_cell
     if is_training and config.keep_prob < 1:
-      lstm_cell = tf.contrib.rnn.DropoutWrapper(
-          lstm_cell, output_keep_prob=config.keep_prob)
+      def attn_cell():
+        return tf.contrib.rnn.DropoutWrapper(
+            lstm_cell(), output_keep_prob=config.keep_prob)
     cell = tf.contrib.rnn.MultiRNNCell(
-        [lstm_cell] * config.num_layers, state_is_tuple=True)
+        [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
 
     self._initial_state = cell.zero_state(batch_size, data_type())
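
Note on the hunk above: in Python, [lstm_cell] * config.num_layers replicates a single reference, so every layer of the MultiRNNCell would receive the very same cell object. Wrapping construction in a function forces a fresh instance per layer. A minimal, dependency-free sketch of the aliasing, with FakeCell as an illustrative stand-in rather than tutorial code:

    # List replication copies the reference, not the object:
    # [cell] * n yields n aliases of ONE cell.
    class FakeCell(object):
        pass

    num_layers = 3
    shared = [FakeCell()] * num_layers               # pattern being removed
    fresh = [FakeCell() for _ in range(num_layers)]  # pattern being added

    assert shared[0] is shared[1] is shared[2]        # one instance, 3 aliases
    assert len({id(c) for c in fresh}) == num_layers  # 3 distinct instances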
 

tutorials/rnn/translate/seq2seq_model.py (+6, -4)

@@ -114,12 +114,14 @@ class Seq2SeqModel(object):
       softmax_loss_function = sampled_loss
 
     # Create the internal multi-layer cell for our RNN.
-    single_cell = tf.nn.rnn_cell.GRUCell(size)
+    def single_cell():
+      return tf.nn.rnn_cell.GRUCell(size)
     if use_lstm:
-      single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
-    cell = single_cell
+      def single_cell():
+        return tf.nn.rnn_cell.BasicLSTMCell(size)
+    cell = single_cell()
     if num_layers > 1:
-      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
+      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)])
 
     # The seq2seq function: we use embedding for the input and attention.
     def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
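
Both files apply the same fix: move construction into a zero-argument factory and call it once per layer. The toy sketch below models why the aliasing is harmful; Cell and its build method are illustrative stand-ins for how an RNN cell lazily creates its variables, not TensorFlow API:

    # Toy model: a cell that builds its weights on first use.
    class Cell(object):
        def __init__(self):
            self.weights = None

        def build(self, layer_name):
            if self.weights is None:  # only the first call creates weights
                self.weights = "weights_of_" + layer_name

    # Aliased pattern: layer 1 silently reuses layer 0's weights.
    shared = [Cell()] * 2
    for i, cell in enumerate(shared):
        cell.build("layer%d" % i)
    print(shared[1].weights)  # weights_of_layer0 (shared across layers)

    # Factory pattern from this commit: independent weights per layer.
    fresh = [Cell() for _ in range(2)]
    for i, cell in enumerate(fresh):
        cell.build("layer%d" % i)
    print(fresh[1].weights)   # weights_of_layer1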