model.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """RNN model with embeddings"""
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import tensorflow as tf
  19. class NamignizerModel(object):
  20. """The Namignizer model ~ strongly based on PTB"""
  21. def __init__(self, is_training, config):
  22. self.batch_size = batch_size = config.batch_size
  23. self.num_steps = num_steps = config.num_steps
  24. size = config.hidden_size
  25. # will always be 27
  26. vocab_size = config.vocab_size
  27. # placeholders for inputs
  28. self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
  29. self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
  30. # weights for the loss function
  31. self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
  32. # lstm for our RNN cell (GRU supported too)
  33. lstm_cells = []
  34. for layer in range(config.num_layers):
  35. lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
  36. if is_training and config.keep_prob < 1:
  37. lstm_cell = tf.contrib.rnn.DropoutWrapper(
  38. lstm_cell, output_keep_prob=config.keep_prob)
  39. lstm_cells.append(lstm_cell)
  40. cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)
  41. self._initial_state = cell.zero_state(batch_size, tf.float32)
  42. with tf.device("/cpu:0"):
  43. embedding = tf.get_variable("embedding", [vocab_size, size])
  44. inputs = tf.nn.embedding_lookup(embedding, self._input_data)
  45. if is_training and config.keep_prob < 1:
  46. inputs = tf.nn.dropout(inputs, config.keep_prob)
  47. outputs = []
  48. state = self._initial_state
  49. with tf.variable_scope("RNN"):
  50. for time_step in range(num_steps):
  51. if time_step > 0:
  52. tf.get_variable_scope().reuse_variables()
  53. (cell_output, state) = cell(inputs[:, time_step, :], state)
  54. outputs.append(cell_output)
  55. output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
  56. softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
  57. softmax_b = tf.get_variable("softmax_b", [vocab_size])
  58. logits = tf.matmul(output, softmax_w) + softmax_b
  59. loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
  60. [logits],
  61. [tf.reshape(self._targets, [-1])],
  62. [self._weights])
  63. self._loss = loss
  64. self._cost = cost = tf.reduce_sum(loss) / batch_size
  65. self._final_state = state
  66. # probabilities of each letter
  67. self._activations = tf.nn.softmax(logits)
  68. # ability to save the model
  69. self.saver = tf.train.Saver(tf.global_variables())
  70. if not is_training:
  71. return
  72. self._lr = tf.Variable(0.0, trainable=False)
  73. tvars = tf.trainable_variables()
  74. grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
  75. config.max_grad_norm)
  76. optimizer = tf.train.GradientDescentOptimizer(self.lr)
  77. self._train_op = optimizer.apply_gradients(zip(grads, tvars))
  78. def assign_lr(self, session, lr_value):
  79. session.run(tf.assign(self.lr, lr_value))
  80. @property
  81. def input_data(self):
  82. return self._input_data
  83. @property
  84. def targets(self):
  85. return self._targets
  86. @property
  87. def activations(self):
  88. return self._activations
  89. @property
  90. def weights(self):
  91. return self._weights
  92. @property
  93. def initial_state(self):
  94. return self._initial_state
  95. @property
  96. def cost(self):
  97. return self._cost
  98. @property
  99. def loss(self):
  100. return self._loss
  101. @property
  102. def final_state(self):
  103. return self._final_state
  104. @property
  105. def lr(self):
  106. return self._lr
  107. @property
  108. def train_op(self):
  109. return self._train_op