Prechádzať zdrojové kódy

Add pre-trained lm_1b model.

Xin Pan 9 rokov pred
rodič
commit
2a5a559689
4 zmenil súbory, kde vykonal 804 pridanie a 0 odobranie
  1. 27 0
      lm_1b/BUILD
  2. 191 0
      lm_1b/README.md
  3. 279 0
      lm_1b/data_utils.py
  4. 307 0
      lm_1b/lm_1b_eval.py

+ 27 - 0
lm_1b/BUILD

@@ -0,0 +1,27 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//lm_1b/...",
+    ],
+)
+
+py_library(
+    name = "data_utils",
+    srcs = ["data_utils.py"],
+)
+
+py_binary(
+    name = "lm_1b_eval",
+    srcs = [
+        "lm_1b_eval.py",
+    ],
+    deps = [
+        ":data_utils",
+    ],
+)

+ 191 - 0
lm_1b/README.md

@@ -0,0 +1,191 @@
+<font size=4><b>Language Model on One Billion Word Benchmark</b></font>
+
+<b>Authors:</b>
+
+Oriol Vinyals (vinyals@google.com, github: OriolVinyals),
+Xin Pan (xpan@google.com, github: panyx0718)
+
+<b>Paper Authors:</b>
+
+Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, Yonghui Wu
+
+<b>TL;DR</b>
+
+This is a pretrained model on One Billion Word Benchmark.
+If you use this model in your publication, please cite the original paper:
+
+@article{jozefowicz2016exploring,
+  title={Exploring the Limits of Language Modeling},
+  author={Jozefowicz, Rafal and Vinyals, Oriol and Schuster, Mike
+          and Shazeer, Noam and Wu, Yonghui},
+  journal={arXiv preprint arXiv:1602.02410},
+  year={2016}
+}
+
+<b>Introduction</b>
+
+In this release, we open source a model trained on the One Billion Word
+Benchmark (http://arxiv.org/abs/1312.3005), a large language corpus in English
+which was released in 2013. This dataset contains about one billion words, and
+has a vocabulary size of about 800K words. It contains mostly news data. Since
+sentences in the training set are shuffled, models can ignore the context and
+focus on sentence level language modeling.
+
+In the original release and subsequent work, people have used the same test set
+to train models on this dataset as a standard benchmark for language modeling.
+Recently, we wrote an article (http://arxiv.org/abs/1602.02410) describing a
+model hybrid between character CNN, a large and deep LSTM, and a specific
+Softmax architecture which allowed us to train the best model on this dataset
+thus far, almost halving the best perplexity previously obtained by others.
+
+<b>Code Release</b>
+
+The open-sourced components include:
+
+* TensorFlow GraphDef proto buffer text file.
+* TensorFlow pre-trained checkpoint shards.
+* Code used to evaluate the pre-trained model.
+* Vocabulary file.
+* Test set from LM-1B evaluation.
+
+The code supports 4 evaluation modes:
+
+* Given provided dataset, calculate the model's perplexity.
+* Given a prefix sentence, predict the next words.
+* Dump the softmax embedding, character-level CNN word embeddings.
+* Give a sentence, dump the embedding from the LSTM state.
+
+<b>Results</b>
+
+Model | Test Perplexity | Number of Params [billions]
+------|-----------------|----------------------------
+Sigmoid-RNN-2048 [Blackout] | 68.3 | 4.1
+Interpolated KN 5-gram, 1.1B n-grams [chelba2013one] | 67.6 | 1.76
+Sparse Non-Negative Matrix LM [shazeer2015sparse] | 52.9 | 33
+RNN-1024 + MaxEnt 9-gram features [chelba2013one] | 51.3 | 20
+LSTM-512-512 | 54.1 | 0.82
+LSTM-1024-512 | 48.2 | 0.82
+LSTM-2048-512 | 43.7 | 0.83
+LSTM-8192-2048 (No Dropout) | 37.9 | 3.3
+LSTM-8192-2048 (50\% Dropout) | 32.2 | 3.3
+2-Layer LSTM-8192-1024 (BIG LSTM) | 30.6 | 1.8
+(THIS RELEASE) BIG LSTM+CNN Inputs | <b>30.0</b> | <b>1.04</b>
+
+<b>How To Run</b>
+
+Pre-requesite:
+
+* Install TensorFlow.
+* Install Bazel.
+* Download the data files:
+  * Model GraphDef file:
+  [link](download.tensorflow.org/models/LM_LSTM_CNN/graph-2016-09-10.pbtxt)
+  * Model Checkpoint sharded file:
+  [1](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-base)
+  [2](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-char-embedding)
+  [3](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-lstm)
+  [4](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax0)
+  [5](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax1)
+  [6](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax2)
+  [7](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax3)
+  [8](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax4)
+  [9](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax5)
+  [10](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax6)
+  [11](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax7)
+  [12](download.tensorflow.org/models/LM_LSTM_CNN/all_shards-2016-09-10/ckpt-softmax8)
+  * Vocabulary file:
+  [link](download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt)
+  * test dataset: link
+  [link](download.tensorflow.org/models/LM_LSTM_CNN/test/news.en.heldout-00000-of-00050)
+* It is recommended to run on modern desktop PC instead of laptop.
+
+```shell
+# 1. Clone the code to your workspace.
+# 2. Download the data to your workspace.
+# 3. Create an empty WORKSPACE file in your workspace.
+# 4. Create an empty output directory in your workspace.
+# Example directory structure below:
+ls -R
+.:
+data  lm_1b  output  WORKSPACE
+
+./data:
+ckpt  eval_2_8k_1k_1_1_char.pbtxt  news.en.heldout-00000-of-00050  vocab.txt
+
+./lm_1b:
+BUILD  data_utils.py  data_utils.pyc  lm_1b_eval.py  README.md
+
+./output:
+
+# Build the codes.
+bazel build -c opt lm_1b/...
+# Run sample mode:
+bazel-bin/lm_1b/lm_1b_eval --mode sample \
+                           --prefix "I love that I" \
+                           --pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
+                           --vocab_file data/vocab.txt  \
+                           --ckpt data/ckpt
+...(omitted some TensorFlow output)
+I love
+I love that
+I love that I
+I love that I find
+I love that I find that
+I love that I find that amazing
+...(omitted)
+
+# Run eval mode:
+bazel-bin/lm_1b/lm_1b_eval --mode eval \
+                           --pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
+                           --vocab_file data/vocab.txt  \
+                           --input_data data/news.en.heldout-00000-of-00050 \
+                           --ckpt data/ckpt
+...(omitted some TensorFlow output)
+Loaded step 14108582.
+# perplexity is high initially because words without context are harder to
+# predict.
+Eval Step: 0, Average Perplexity: 2045.512297.
+Eval Step: 1, Average Perplexity: 229.478699.
+Eval Step: 2, Average Perplexity: 208.116787.
+Eval Step: 3, Average Perplexity: 338.870601.
+Eval Step: 4, Average Perplexity: 228.950107.
+Eval Step: 5, Average Perplexity: 197.685857.
+Eval Step: 6, Average Perplexity: 156.287063.
+Eval Step: 7, Average Perplexity: 124.866189.
+Eval Step: 8, Average Perplexity: 147.204975.
+Eval Step: 9, Average Perplexity: 90.124864.
+Eval Step: 10, Average Perplexity: 59.897914.
+Eval Step: 11, Average Perplexity: 42.591137.
+...(omitted)
+Eval Step: 4529, Average Perplexity: 29.243668.
+Eval Step: 4530, Average Perplexity: 29.302362.
+Eval Step: 4531, Average Perplexity: 29.285674.
+...(omitted. At convergence, it should be around 30.)
+
+# Run dump_emb mode:
+bazel-bin/lm_1b/lm_1b_eval --mode dump_emb \
+                           --pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
+                           --vocab_file data/vocab.txt  \
+                           --ckpt data/ckpt \
+                           --save_dir output
+...(omitted some TensorFlow output)
+Finished softmax weights
+Finished word embedding 0/793471
+Finished word embedding 1/793471
+Finished word embedding 2/793471
+...(omitted)
+ls output/
+embeddings_softmax.npy ...
+
+# Run dump_lstm_emb mode:
+bazel-bin/lm_1b/lm_1b_eval --mode dump_lstm_emb \
+                           --pbtxt data/eval_2_8k_1k_1_1_char.pbtxt \
+                           --vocab_file data/vocab.txt \
+                           --ckpt data/ckpt \
+                           --sentence "I love who I am ." \
+                           --save_dir output
+ls output/
+lstm_emb_step_0.npy  lstm_emb_step_2.npy  lstm_emb_step_4.npy
+lstm_emb_step_6.npy  lstm_emb_step_1.npy  lstm_emb_step_3.npy
+lstm_emb_step_5.npy
+```

+ 279 - 0
lm_1b/data_utils.py

@@ -0,0 +1,279 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A library for loading 1B word benchmark dataset."""
+
+import random
+
+import numpy as np
+import tensorflow as tf
+
+
+class Vocabulary(object):
+  """Class that holds a vocabulary for the dataset."""
+
+  def __init__(self, filename):
+    """Initialize vocabulary.
+
+    Args:
+      filename: Vocabulary file name.
+    """
+
+    self._id_to_word = []
+    self._word_to_id = {}
+    self._unk = -1
+    self._bos = -1
+    self._eos = -1
+
+    with tf.gfile.Open(filename) as f:
+      idx = 0
+      for line in f:
+        word_name = line.strip()
+        if word_name == '<S>':
+          self._bos = idx
+        elif word_name == '</S>':
+          self._eos = idx
+        elif word_name == '<UNK>':
+          self._unk = idx
+        if word_name == '!!!MAXTERMID':
+          continue
+
+        self._id_to_word.append(word_name)
+        self._word_to_id[word_name] = idx
+        idx += 1
+
+  @property
+  def bos(self):
+    return self._bos
+
+  @property
+  def eos(self):
+    return self._eos
+
+  @property
+  def unk(self):
+    return self._unk
+
+  @property
+  def size(self):
+    return len(self._id_to_word)
+
+  def word_to_id(self, word):
+    if word in self._word_to_id:
+      return self._word_to_id[word]
+    return self.unk
+
+  def id_to_word(self, cur_id):
+    if cur_id < self.size:
+      return self._id_to_word[cur_id]
+    return 'ERROR'
+
+  def decode(self, cur_ids):
+    """Convert a list of ids to a sentence, with space inserted."""
+    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])
+
+  def encode(self, sentence):
+    """Convert a sentence to a list of ids, with special tokens added."""
+    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
+    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
+
+
+class CharsVocabulary(Vocabulary):
+  """Vocabulary containing character-level information."""
+
+  def __init__(self, filename, max_word_length):
+    super(CharsVocabulary, self).__init__(filename)
+    self._max_word_length = max_word_length
+    chars_set = set()
+
+    for word in self._id_to_word:
+      chars_set |= set(word)
+
+    free_ids = []
+    for i in range(256):
+      if chr(i) in chars_set:
+        continue
+      free_ids.append(chr(i))
+
+    if len(free_ids) < 5:
+      raise ValueError('Not enough free char ids: %d' % len(free_ids))
+
+    self.bos_char = free_ids[0]  # <begin sentence>
+    self.eos_char = free_ids[1]  # <end sentence>
+    self.bow_char = free_ids[2]  # <begin word>
+    self.eow_char = free_ids[3]  # <end word>
+    self.pad_char = free_ids[4]  # <padding>
+
+    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
+                  self.pad_char}
+
+    self._char_set = chars_set
+    num_words = len(self._id_to_word)
+
+    self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)
+
+    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
+    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)
+
+    for i, word in enumerate(self._id_to_word):
+      self._word_char_ids[i] = self._convert_word_to_char_ids(word)
+
+  @property
+  def word_char_ids(self):
+    return self._word_char_ids
+
+  @property
+  def max_word_length(self):
+    return self._max_word_length
+
+  def _convert_word_to_char_ids(self, word):
+    code = np.zeros([self.max_word_length], dtype=np.int32)
+    code[:] = ord(self.pad_char)
+
+    if len(word) > self.max_word_length - 2:
+      word = word[:self.max_word_length-2]
+    cur_word = self.bow_char + word + self.eow_char
+    for j in range(len(cur_word)):
+      code[j] = ord(cur_word[j])
+    return code
+
+  def word_to_char_ids(self, word):
+    if word in self._word_to_id:
+      return self._word_char_ids[self._word_to_id[word]]
+    else:
+      return self._convert_word_to_char_ids(word)
+
+  def encode_chars(self, sentence):
+    chars_ids = [self.word_to_char_ids(cur_word)
+                 for cur_word in sentence.split()]
+    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
+
+
+def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
+  """Read batches of input."""
+  cur_stream = [None] * batch_size
+
+  inputs = np.zeros([batch_size, num_steps], np.int32)
+  char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
+  global_word_ids = np.zeros([batch_size, num_steps], np.int32)
+  targets = np.zeros([batch_size, num_steps], np.int32)
+  weights = np.ones([batch_size, num_steps], np.float32)
+
+  no_more_data = False
+  while True:
+    inputs[:] = 0
+    char_inputs[:] = 0
+    global_word_ids[:] = 0
+    targets[:] = 0
+    weights[:] = 0.0
+
+    for i in range(batch_size):
+      cur_pos = 0
+
+      while cur_pos < num_steps:
+        if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
+          try:
+            cur_stream[i] = list(generator.next())
+          except StopIteration:
+            # No more data, exhaust current streams and quit
+            no_more_data = True
+            break
+
+        how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
+        next_pos = cur_pos + how_many
+
+        inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
+        char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
+        global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
+        targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many+1]
+        weights[i, cur_pos:next_pos] = 1.0
+
+        cur_pos = next_pos
+        cur_stream[i][0] = cur_stream[i][0][how_many:]
+        cur_stream[i][1] = cur_stream[i][1][how_many:]
+        cur_stream[i][2] = cur_stream[i][2][how_many:]
+
+        if pad:
+          break
+
+    if no_more_data and np.sum(weights) == 0:
+      # There is no more data and this is an empty batch. Done!
+      break
+    yield inputs, char_inputs, global_word_ids, targets, weights
+
+
+class LM1BDataset(object):
+  """Utility class for 1B word benchmark dataset.
+
+  The current implementation reads the data from the tokenized text files.
+  """
+
+  def __init__(self, filepattern, vocab):
+    """Initialize LM1BDataset reader.
+
+    Args:
+      filepattern: Dataset file pattern.
+      vocab: Vocabulary.
+    """
+    self._vocab = vocab
+    self._all_shards = tf.gfile.Glob(filepattern)
+    tf.logging.info('Found %d shards at %s', len(self._all_shards), filepattern)
+
+  def _load_random_shard(self):
+    """Randomly select a file and read it."""
+    return self._load_shard(random.choice(self._all_shards))
+
+  def _load_shard(self, shard_name):
+    """Read one file and convert to ids.
+
+    Args:
+      shard_name: file path.
+
+    Returns:
+      list of (id, char_id, global_word_id) tuples.
+    """
+    tf.logging.info('Loading data from: %s', shard_name)
+    with tf.gfile.Open(shard_name) as f:
+      sentences = f.readlines()
+    chars_ids = [self.vocab.encode_chars(sentence) for sentence in sentences]
+    ids = [self.vocab.encode(sentence) for sentence in sentences]
+
+    global_word_ids = []
+    current_idx = 0
+    for word_ids in ids:
+      current_size = len(word_ids) - 1  # without <BOS> symbol
+      cur_ids = np.arange(current_idx, current_idx + current_size)
+      global_word_ids.append(cur_ids)
+      current_idx += current_size
+
+    tf.logging.info('Loaded %d words.', current_idx)
+    tf.logging.info('Finished loading')
+    return zip(ids, chars_ids, global_word_ids)
+
+  def _get_sentence(self, forever=True):
+    while True:
+      ids = self._load_random_shard()
+      for current_ids in ids:
+        yield current_ids
+      if not forever:
+        break
+
+  def get_batch(self, batch_size, num_steps, pad=False, forever=True):
+    return get_batch(self._get_sentence(forever), batch_size, num_steps,
+                     self.vocab.max_word_length, pad=pad)
+
+  @property
+  def vocab(self):
+    return self._vocab

+ 307 - 0
lm_1b/lm_1b_eval.py

@@ -0,0 +1,307 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Eval pre-trained 1 billion word language model.
+"""
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from google.protobuf import text_format
+import data_utils
+
+FLAGS = tf.flags.FLAGS
+# General flags.
+tf.flags.DEFINE_string('mode', 'eval',
+                       'One of [sample, eval, dump_emb, dump_lstm_emb]. '
+                       '"sample" mode samples future word predictions, using '
+                       'FLAGS.prefix as prefix (prefix could be left empty). '
+                       '"eval" mode calculates perplexity of the '
+                       'FLAGS.input_data. '
+                       '"dump_emb" mode dumps word and softmax embeddings to '
+                       'FLAGS.save_dir. embeddings are dumped in the same '
+                       'order as words in vocabulary. All words in vocabulary '
+                       'are dumped.'
+                       'dump_lstm_emb dumps lstm embeddings of FLAGS.sentence '
+                       'to FLAGS.save_dir.')
+tf.flags.DEFINE_string('pbtxt', '',
+                       'GraphDef proto text file used to construct model '
+                       'structure.')
+tf.flags.DEFINE_string('ckpt', '',
+                       'Checkpoint directory used to fill model values.')
+tf.flags.DEFINE_string('vocab_file', '', 'Vocabulary file.')
+tf.flags.DEFINE_string('save_dir', '',
+                       'Used for "dump_emb" mode to save word embeddings.')
+# sample mode flags.
+tf.flags.DEFINE_string('prefix', '',
+                       'Used for "sample" mode to predict next words.')
+tf.flags.DEFINE_integer('max_sample_words', 100,
+                        'Sampling stops either when </S> is met or this number '
+                        'of steps has passed.')
+tf.flags.DEFINE_integer('num_samples', 3,
+                        'Number of samples to generate for the prefix.')
+# dump_lstm_emb mode flags.
+tf.flags.DEFINE_string('sentence', '',
+                       'Used as input for "dump_lstm_emb" mode.')
+# eval mode flags.
+tf.flags.DEFINE_string('input_data', '',
+                       'Input data files for eval model.')
+tf.flags.DEFINE_integer('max_eval_steps', 1000000,
+                        'Maximum mumber of steps to run "eval" mode.')
+
+
+# For saving demo resources, use batch size 1 and step 1.
+BATCH_SIZE = 1
+NUM_TIMESTEPS = 1
+MAX_WORD_LEN = 50
+
+
+def _LoadModel(gd_file, ckpt_file):
+  """Load the model from GraphDef and Checkpoint.
+
+  Args:
+    gd_file: GraphDef proto text file.
+    ckpt_file: TensorFlow Checkpoint file.
+
+  Returns:
+    TensorFlow session and tensors dict.
+  """
+  with tf.Graph().as_default():
+    sys.stderr.write('Recovering graph.\n')
+    with tf.gfile.FastGFile(gd_file, 'r') as f:
+      s = f.read()
+      gd = tf.GraphDef()
+      text_format.Merge(s, gd)
+
+    tf.logging.info('Recovering Graph %s', gd_file)
+    t = {}
+    [t['states_init'], t['lstm/lstm_0/control_dependency'],
+     t['lstm/lstm_1/control_dependency'], t['softmax_out'], t['class_ids_out'],
+     t['class_weights_out'], t['log_perplexity_out'], t['inputs_in'],
+     t['targets_in'], t['target_weights_in'], t['char_inputs_in'],
+     t['all_embs'], t['softmax_weights'], t['global_step']
+    ] = tf.import_graph_def(gd, {}, ['states_init',
+                                     'lstm/lstm_0/control_dependency:0',
+                                     'lstm/lstm_1/control_dependency:0',
+                                     'softmax_out:0',
+                                     'class_ids_out:0',
+                                     'class_weights_out:0',
+                                     'log_perplexity_out:0',
+                                     'inputs_in:0',
+                                     'targets_in:0',
+                                     'target_weights_in:0',
+                                     'char_inputs_in:0',
+                                     'all_embs_out:0',
+                                     'Reshape_3:0',
+                                     'global_step:0'], name='')
+
+    sys.stderr.write('Recovering checkpoint %s\n' % ckpt_file)
+    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+    sess.run('save/restore_all', {'save/Const:0': ckpt_file})
+    sess.run(t['states_init'])
+
+  return sess, t
+
+
+def _EvalModel(dataset):
+  """Evaluate model perplexity using provided dataset.
+
+  Args:
+    dataset: LM1BDataset object.
+  """
+  sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  current_step = t['global_step'].eval(session=sess)
+  sys.stderr.write('Loaded step %d.\n' % current_step)
+
+  data_gen = dataset.get_batch(BATCH_SIZE, NUM_TIMESTEPS, forever=False)
+  sum_num = 0.0
+  sum_den = 0.0
+  perplexity = 0.0
+  for i, (inputs, char_inputs, _, targets, weights) in enumerate(data_gen):
+    input_dict = {t['inputs_in']: inputs,
+                  t['targets_in']: targets,
+                  t['target_weights_in']: weights}
+    if 'char_inputs_in' in t:
+      input_dict[t['char_inputs_in']] = char_inputs
+    log_perp = sess.run(t['log_perplexity_out'], feed_dict=input_dict)
+
+    if np.isnan(log_perp):
+      sys.stderr.error('log_perplexity is Nan.\n')
+    else:
+      sum_num += log_perp * weights.mean()
+      sum_den += weights.mean()
+    if sum_den > 0:
+      perplexity = np.exp(sum_num / sum_den)
+
+    sys.stderr.write('Eval Step: %d, Average Perplexity: %f.\n' %
+                     (i, perplexity))
+
+    if i > FLAGS.max_eval_steps:
+      break
+
+
+def _SampleSoftmax(softmax):
+  return min(np.sum(np.cumsum(softmax) < np.random.rand()), len(softmax) - 1)
+
+
+def _SampleModel(prefix_words, vocab):
+  """Predict next words using the given prefix words.
+
+  Args:
+    prefix_words: Prefix words.
+    vocab: Vocabulary. Contains max word chard id length and converts between
+        words and ids.
+  """
+  targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+  weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+  sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  if prefix_words.find('<S>') != 0:
+    prefix_words = '<S> ' + prefix_words
+
+  prefix = [vocab.word_to_id(w) for w in prefix_words.split()]
+  prefix_char_ids = [vocab.word_to_char_ids(w) for w in prefix_words.split()]
+  for _ in xrange(FLAGS.num_samples):
+    inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+    char_ids_inputs = np.zeros(
+        [BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
+    samples = prefix[:]
+    char_ids_samples = prefix_char_ids[:]
+    sent = ''
+    while True:
+      inputs[0, 0] = samples[0]
+      char_ids_inputs[0, 0, :] = char_ids_samples[0]
+      samples = samples[1:]
+      char_ids_samples = char_ids_samples[1:]
+
+      softmax = sess.run(t['softmax_out'],
+                         feed_dict={t['char_inputs_in']: char_ids_inputs,
+                                    t['inputs_in']: inputs,
+                                    t['targets_in']: targets,
+                                    t['target_weights_in']: weights})
+
+      sample = _SampleSoftmax(softmax[0])
+      sample_char_ids = vocab.word_to_char_ids(vocab.id_to_word(sample))
+
+      if not samples:
+        samples = [sample]
+        char_ids_samples = [sample_char_ids]
+      sent += vocab.id_to_word(samples[0]) + ' '
+      sys.stderr.write('%s\n' % sent)
+
+      if (vocab.id_to_word(samples[0]) == '</S>' or
+          len(sent) > FLAGS.max_sample_words):
+        break
+
+
+def _DumpEmb(vocab):
+  """Dump the softmax weights and word embeddings to files.
+
+  Args:
+    vocab: Vocabulary. Contains vocabulary size and converts word to ids.
+  """
+  assert FLAGS.save_dir, 'Must specify FLAGS.save_dir for dump_emb.'
+  inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+  targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+  weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+  sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  softmax_weights = sess.run(t['softmax_weights'])
+  fname = FLAGS.save_dir + '/embeddings_softmax.npy'
+  with tf.gfile.Open(fname, mode='w') as f:
+    np.save(f, softmax_weights)
+  sys.stderr.write('Finished softmax weights\n')
+
+  all_embs = np.zeros([vocab.size, 1024])
+  for i in range(vocab.size):
+    input_dict = {t['inputs_in']: inputs,
+                  t['targets_in']: targets,
+                  t['target_weights_in']: weights}
+    if 'char_inputs_in' in t:
+      input_dict[t['char_inputs_in']] = (
+          vocab.word_char_ids[i].reshape([-1, 1, MAX_WORD_LEN]))
+    embs = sess.run(t['all_embs'], input_dict)
+    all_embs[i, :] = embs
+    sys.stderr.write('Finished word embedding %d/%d\n' % (i, vocab.size))
+
+  fname = FLAGS.save_dir + '/embeddings_char_cnn.npy'
+  with tf.gfile.Open(fname, mode='w') as f:
+    np.save(f, all_embs)
+  sys.stderr.write('Embedding file saved\n')
+
+
+def _DumpSentenceEmbedding(sentence, vocab):
+  """Predict next words using the given prefix words.
+
+  Args:
+    sentence: Sentence words.
+    vocab: Vocabulary. Contains max word chard id length and converts between
+        words and ids.
+  """
+  targets = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+  weights = np.ones([BATCH_SIZE, NUM_TIMESTEPS], np.float32)
+
+  sess, t = _LoadModel(FLAGS.pbtxt, FLAGS.ckpt)
+
+  if sentence.find('<S>') != 0:
+    sentence = '<S> ' + sentence
+
+  word_ids = [vocab.word_to_id(w) for w in sentence.split()]
+  char_ids = [vocab.word_to_char_ids(w) for w in sentence.split()]
+
+  inputs = np.zeros([BATCH_SIZE, NUM_TIMESTEPS], np.int32)
+  char_ids_inputs = np.zeros(
+      [BATCH_SIZE, NUM_TIMESTEPS, vocab.max_word_length], np.int32)
+  for i in xrange(len(word_ids)):
+    inputs[0, 0] = word_ids[i]
+    char_ids_inputs[0, 0, :] = char_ids[i]
+
+    # Add 'lstm/lstm_0/control_dependency' if you want to dump previous layer
+    # LSTM.
+    lstm_emb = sess.run(t['lstm/lstm_1/control_dependency'],
+                        feed_dict={t['char_inputs_in']: char_ids_inputs,
+                                   t['inputs_in']: inputs,
+                                   t['targets_in']: targets,
+                                   t['target_weights_in']: weights})
+
+    fname = os.path.join(FLAGS.save_dir, 'lstm_emb_step_%d.npy' % i)
+    with tf.gfile.Open(fname, mode='w') as f:
+      np.save(f, lstm_emb)
+    sys.stderr.write('LSTM embedding step %d file saved\n' % i)
+
+
+def main(unused_argv):
+  vocab = data_utils.CharsVocabulary(FLAGS.vocab_file, MAX_WORD_LEN)
+
+  if FLAGS.mode == 'eval':
+    dataset = data_utils.LM1BDataset(FLAGS.input_data, vocab)
+    _EvalModel(dataset)
+  elif FLAGS.mode == 'sample':
+    _SampleModel(FLAGS.prefix, vocab)
+  elif FLAGS.mode == 'dump_emb':
+    _DumpEmb(vocab)
+  elif FLAGS.mode == 'dump_lstm_emb':
+    _DumpSentenceEmbedding(FLAGS.sentence, vocab)
+  else:
+    raise Exception('Mode not supported.')
+
+
+if __name__ == '__main__':
+  tf.app.run()