model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RNN model with embeddings"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class NamignizerModel(object):
    """The Namignizer model ~ strongly based on PTB"""

    def __init__(self, is_training, config):
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps = config.num_steps
        size = config.hidden_size
        # will always be 27
        vocab_size = config.vocab_size

        # placeholders for inputs
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])

        # weights for the loss function
        self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
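        # (one weight per unrolled position, so callers can e.g. zero out
        # padded steps and keep them from contributing to the loss)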

        # lstm for our RNN cell (GRU supported too)
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
        if is_training and config.keep_prob < 1:
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
                lstm_cell, output_keep_prob=config.keep_prob)
        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)

        self._initial_state = cell.zero_state(batch_size, tf.float32)
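
        # keep the embedding on the CPU; in the TF releases this code targets,
        # embedding lookups ran on the CPU (the same placement the PTB
        # tutorial this model is based on uses)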
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [vocab_size, size])
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)
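
        # statically unroll the network num_steps times; reuse_variables()
        # makes every time step share a single set of LSTM weights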
        outputs = []
        state = self._initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(1, outputs), [-1, size])
        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b
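
        # per-position cross-entropy over the flattened
        # [batch_size * num_steps] positions, scaled by self._weights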
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [logits],
            [tf.reshape(self._targets, [-1])],
            [self._weights])
        self._loss = loss
        self._cost = cost = tf.reduce_sum(loss) / batch_size

        self._final_state = state

        # probabilities of each letter
        self._activations = tf.nn.softmax(logits)

        # ability to save the model
        self.saver = tf.train.Saver(tf.all_variables())

        if not is_training:
            return

        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
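        # clip the global gradient norm at config.max_grad_norm to keep the
        # SGD updates stable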
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    def assign_lr(self, session, lr_value):
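        """Set the learning-rate variable to `lr_value` in `session`."""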
        session.run(tf.assign(self.lr, lr_value))

    @property
    def input_data(self):
        return self._input_data

    @property
    def targets(self):
        return self._targets

    @property
    def activations(self):
        return self._activations

    @property
    def weights(self):
        return self._weights

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def cost(self):
        return self._cost

    @property
    def loss(self):
        return self._loss

    @property
    def final_state(self):
        return self._final_state

    @property
    def lr(self):
        return self._lr

    @property
    def train_op(self):
        return self._train_op
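

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal training step, shown
# for orientation. The `Config` fields mirror the hyperparameters read in
# __init__ above; the concrete values and the dummy batch arrays are
# illustrative assumptions, not anything defined in this module.
#
#     import numpy as np
#
#     class Config(object):
#         batch_size = 20
#         num_steps = 25
#         hidden_size = 200
#         vocab_size = 27
#         keep_prob = 1.0
#         num_layers = 2
#         max_grad_norm = 5
#
#     config = Config()
#     with tf.Graph().as_default(), tf.Session() as session:
#         model = NamignizerModel(is_training=True, config=config)
#         session.run(tf.initialize_all_variables())
#         model.assign_lr(session, 1.0)
#
#         # dummy batch of the shapes the placeholders expect
#         x = np.zeros([config.batch_size, config.num_steps], dtype=np.int32)
#         y = np.zeros([config.batch_size, config.num_steps], dtype=np.int32)
#         w = np.ones([config.batch_size * config.num_steps], dtype=np.float32)
#
#         cost, _ = session.run([model.cost, model.train_op],
#                               {model.input_data: x,
#                                model.targets: y,
#                                model.weights: w})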