# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """RNN model with embeddings"""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import tensorflow as tf
class NamignizerModel(object):
    """The Namignizer model ~ strongly based on PTB"""

    def __init__(self, is_training, config):
        self.batch_size = batch_size = config.batch_size
        self.num_steps = num_steps = config.num_steps
        size = config.hidden_size
        # will always be 27
        vocab_size = config.vocab_size

        # placeholders for inputs
        self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
        self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
        # weights for the loss function
        self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])
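        # Note: the weights are flattened to [batch_size * num_steps] so they
        # line up one-to-one with the flattened logits and targets consumed by
        # sequence_loss_by_example below (e.g. to down-weight padded steps).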
        # lstm for our RNN cell (GRU supported too)
        lstm_cells = []
        for layer in range(config.num_layers):
            lstm_cell = tf.contrib.rnn.BasicLSTMCell(size, forget_bias=0.0)
            if is_training and config.keep_prob < 1:
                lstm_cell = tf.contrib.rnn.DropoutWrapper(
                    lstm_cell, output_keep_prob=config.keep_prob)
            lstm_cells.append(lstm_cell)
        cell = tf.contrib.rnn.MultiRNNCell(lstm_cells)

        self._initial_state = cell.zero_state(batch_size, tf.float32)
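        # Embed the character ids: each of the vocab_size ids maps to a dense
        # vector of length `size`. The lookup is pinned to the CPU.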
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [vocab_size, size])
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        if is_training and config.keep_prob < 1:
            inputs = tf.nn.dropout(inputs, config.keep_prob)

        outputs = []
        state = self._initial_state
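        # Statically unroll the RNN for num_steps timesteps; the same cell
        # variables are reused at every step via reuse_variables().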
        with tf.variable_scope("RNN"):
            for time_step in range(num_steps):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)
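        # Stack the per-step outputs into shape [batch_size * num_steps, size]
        # and project them onto the vocabulary to get the logits.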
        output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
        softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
        softmax_b = tf.get_variable("softmax_b", [vocab_size])
        logits = tf.matmul(output, softmax_w) + softmax_b
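        # Per-timestep cross-entropy against the flattened targets, weighted by
        # the self._weights placeholder; the cost sums these losses and
        # normalizes by the batch size.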
        loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [logits],
            [tf.reshape(self._targets, [-1])],
            [self._weights])
        self._loss = loss
        self._cost = cost = tf.reduce_sum(loss) / batch_size
        self._final_state = state

        # probabilities of each letter
        self._activations = tf.nn.softmax(logits)

        # ability to save the model
        self.saver = tf.train.Saver(tf.global_variables())

        if not is_training:
            return
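        # Training ops: gradients are clipped by global norm and applied with
        # plain SGD; the learning rate is a non-trainable variable set from
        # outside the graph through assign_lr().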
        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))

    @property
    def input_data(self):
        return self._input_data

    @property
    def targets(self):
        return self._targets

    @property
    def activations(self):
        return self._activations

    @property
    def weights(self):
        return self._weights

    @property
    def initial_state(self):
        return self._initial_state

    @property
    def cost(self):
        return self._cost

    @property
    def loss(self):
        return self._loss

    @property
    def final_state(self):
        return self._final_state

    @property
    def lr(self):
        return self._lr

    @property
    def train_op(self):
        return self._train_op
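
# Illustrative usage sketch: the `config` object is assumed to expose the
# attributes the constructor reads above (batch_size, num_steps, hidden_size,
# vocab_size, num_layers, keep_prob, max_grad_norm). A hypothetical small
# config and TF 1.x graph-mode training setup could look like this:
#
#     class SmallConfig(object):
#         batch_size = 32
#         num_steps = 20
#         hidden_size = 128
#         vocab_size = 27      # always 27 for this model
#         num_layers = 2
#         keep_prob = 0.9
#         max_grad_norm = 5
#
#     with tf.variable_scope("model"):
#         m = NamignizerModel(is_training=True, config=SmallConfig())
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         m.assign_lr(sess, 1.0)
#         # feed m.input_data, m.targets and m.weights, then run m.train_op
#         # and m.cost for each minibatch of name data.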