|
|
@@ -37,11 +37,14 @@ class NamignizerModel(object):
|
|
|
self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
|
|
|
|
|
|
# lstm for our RNN cell (GRU supported too)
|
|
|
- lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
|
|
|
- if is_training and config.keep_prob < 1:
|
|
|
- lstm_cell = tf.contrib.rnn.DropoutWrapper(
|
|
|
- lstm_cell, output_keep_prob=config.keep_prob)
|
|
|
- cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * config.num_layers)
|
|
|
+ lstm_cells = []
|
|
|
+ for layer in range(config.num_layers):
|
|
|
+ lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
|
|
|
+ if is_training and config.keep_prob < 1:
|
|
|
+ lstm_cell = tf.contrib.rnn.DropoutWrapper(
|
|
|
+ lstm_cell, output_keep_prob=config.keep_prob)
|
|
|
+ lstm_cells.append(lstm_cell)
|
|
|
+ cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)
|
|
|
|
|
|
self._initial_state = cell.zero_state(batch_size, tf.float32)
|
|
|
|