Browse Source

TF implementation of Skip Thoughts.

Christopher Shallue 8 years ago
parent
commit
68609ca78a

+ 1 - 0
README.md

@@ -21,6 +21,7 @@ To propose a model for inclusion please submit a pull request.
 - [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
 - [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
 - [resnet](resnet): deep and wide residual networks.
+- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
 - [slim](slim): image classification models in TF-Slim.
 - [street](street): identify the name of a street (in France) from an image using a Deep RNN.
 - [swivel](swivel): the Swivel algorithm for generating word embeddings.

+ 8 - 0
skip_thoughts/.gitignore

@@ -0,0 +1,8 @@
+/bazel-bin
+/bazel-ci_build-cache
+/bazel-genfiles
+/bazel-out
+/bazel-skip_thoughts
+/bazel-testlogs
+/bazel-tf
+*.pyc

+ 471 - 0
skip_thoughts/README.md

@@ -0,0 +1,471 @@
+# Skip-Thought Vectors
+
+This is a TensorFlow implementation of the model described in:
+
+Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
+Antonio Torralba, Raquel Urtasun, Sanja Fidler.
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf).
+*In NIPS, 2015.*
+
+
+## Contact
+***Code author:*** Chris Shallue
+
+***Pull requests and issues:*** @cshallue
+
+## Contents
+* [Model Overview](#model-overview)
+* [Getting Started](#getting-started)
+    * [Install Required Packages](#install-required-packages)
+    * [Download Pretrained Models (Optional)](#download-pretrained-models-optional)
+* [Training a Model](#training-a-model)
+    * [Prepare the Training Data](#prepare-the-training-data)
+    * [Run the Training Script](#run-the-training-script)
+    * [Track Training Progress](#track-training-progress)
+* [Expanding the Vocabulary](#expanding-the-vocabulary)
+    * [Overview](#overview)
+    * [Preparation](#preparation)
+    * [Run the Vocabulary Expansion Script](#run-the-vocabulary-expansion-script)
+* [Evaluating a Model](#evaluating-a-model)
+    * [Overview](#overview-1)
+    * [Preparation](#preparation-1)
+    * [Run the Evaluation Tasks](#run-the-evaluation-tasks)
+* [Encoding Sentences](#encoding-sentences)
+
+## Model overview
+
+The *Skip-Thoughts* model is a sentence encoder. It learns to encode input
+sentences into a fixed-dimensional vector representation that is useful for many
+tasks, for example to detect paraphrases or to classify whether a product review
+is positive or negative. See the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper for details of the model architecture and more example applications.
+
+A trained *Skip-Thoughts* model will encode similar sentences nearby each other
+in the embedding vector space. The following examples show the nearest neighbor by
+cosine similarity of some sentences from the
+[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/).
+
+
+| Input sentence | Nearest Neighbor |
+|----------------|------------------|
+| Simplistic, silly and tedious. | Trite, banal, cliched, mostly inoffensive. |
+| Not so much farcical as sour. | Not only unfunny, but downright repellent. |
+| A sensitive and astute first feature by Anne-Sophie Birot. | Absorbing character study by André Turpin . |
+| An enthralling, entertaining feature. |  A slick, engrossing melodrama. |
+
+## Getting Started
+
+### Install Required Packages
+First ensure that you have installed the following required packages:
+
+* **Bazel** ([instructions](http://bazel.build/docs/install.html))
+* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
+* **NumPy** ([instructions](http://www.scipy.org/install.html))
+* **scikit-learn** ([instructions](http://scikit-learn.org/stable/install.html))
+* **Natural Language Toolkit (NLTK)**
+    * First install NLTK ([instructions](http://www.nltk.org/install.html))
+    * Then install the NLTK data ([instructions](http://www.nltk.org/data.html))
+* **gensim** ([instructions](https://radimrehurek.com/gensim/install.html))
+    * Only required if you will be expanding your vocabulary with the [word2vec](https://code.google.com/archive/p/word2vec/) model.
+
+
+### Download Pretrained Models (Optional)
+
+You can download model checkpoints pretrained on the
+[BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset in the following
+configurations:
+
+* Unidirectional RNN encoder ("uni-skip" in the paper)
+* Bidirectional RNN encoder ("bi-skip" in the paper)
+
+```shell
+# Directory to download the pretrained models to.
+PRETRAINED_MODELS_DIR="${HOME}/skip_thoughts/pretrained/"
+
+mkdir -p ${PRETRAINED_MODELS_DIR}
+cd ${PRETRAINED_MODELS_DIR}
+
+# Download and extract the unidirectional model.
+wget "http://download.tensorflow.org/models/skip_thoughts_uni_2017_02_02.tar.gz"
+tar -xvf skip_thoughts_uni_2017_02_02.tar
+rm skip_thoughts_uni_2017_02_02.tar
+
+# Download and extract the bidirectional model.
+wget "http://download.tensorflow.org/models/skip_thoughts_bi_2017_02_16.tar.gz"
+tar -xvf skip_thoughts_bi_2017_02_16.tar
+rm skip_thoughts_bi_2017_02_16.tar
+```
+
+You can now skip to the sections [Evaluating a Model](#evaluating-a-model) and
+[Encoding Sentences](#encoding-sentences).
+
+
+## Training a Model
+
+### Prepare the Training Data
+
+To train a model you will need to provide training data in TFRecord format. The
+TFRecord format consists of a set of sharded files containing serialized
+`tf.Example` protocol buffers. Each `tf.Example` proto contains three
+sentences:
+
+  * `encode`: The sentence to encode.
+  * `decode_pre`: The sentence preceding `encode` in the original text.
+  * `decode_post`: The sentence following `encode` in the original text.
+
+Each sentence is a list of words. During preprocessing, a dictionary is created
+that assigns each word in the vocabulary to an integer-valued id. Each sentence
+is encoded as a list of integer word ids in the `tf.Example` protos.
+
+We have provided a script to preprocess any set of text-files into this format.
+You may wish to use the [BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset.
+Note that the preprocessing script may take **12 hours** or more to complete
+on this large dataset.
+
+```shell
+# Comma-separated list of globs matching the input input files. The format of
+# the input files is assumed to be a list of newline-separated sentences, where
+# each sentence is already tokenized.
+INPUT_FILES="${HOME}/skip_thoughts/bookcorpus/*.txt"
+
+# Location to save the preprocessed training and validation data.
+DATA_DIR="${HOME}/skip_thoughts/data"
+
+# Build the preprocessing script.
+bazel build -c opt skip_thoughts/data/preprocess_dataset
+
+# Run the preprocessing script.
+bazel-bin/skip_thoughts/data/preprocess_dataset \
+  --input_files=${INPUT_FILES} \
+  --output_dir=${DATA_DIR}
+```
+
+When the script finishes you will find 100 training files and 1 validation file
+in `DATA_DIR`. The files will match the patterns `train-?????-of-00100` and
+`validation-00000-of-00001` respectively.
+
+The script will also produce a file named `vocab.txt`. The format of this file
+is a list of newline-separated words where the word id is the corresponding 0-
+based line index. Words are sorted by descending order of frequency in the input
+data. Only the top 20,000 words are assigned unique ids; all other words are
+assigned the "unknown id" of 1 in the processed data.
+
+### Run the Training Script
+
+Execute the following commands to start the training script. By default it will
+run for 500k steps (around 9 days on a GeForce GTX 1080 GPU).
+
+```shell
+# Directory containing the preprocessed data.
+DATA_DIR="${HOME}/skip_thoughts/data"
+
+# Directory to save the model.
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Build the model.
+bazel build -c opt skip_thoughts/...
+
+# Run the training script.
+bazel-bin/skip_thoughts/train \
+  --input_file_pattern="${DATA_DIR}/train-?????-of-00100" \
+  --train_dir="${MODEL_DIR}/train"
+```
+
+### Track Training Progress
+
+Optionally, you can run the `track_perplexity` script in a separate process.
+This will log per-word perplexity on the validation set which allows training
+progress to be monitored on
+[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
+
+Note that you may run out of memory if you run the this script on the same GPU
+as the training script. You can set the environment variable
+`CUDA_VISIBLE_DEVICES=""` to force the script to run on CPU. If it runs too
+slowly on CPU, you can decrease the value of `--num_eval_examples`.
+
+```shell
+DATA_DIR="${HOME}/skip_thoughts/data"
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run the evaluation script. This will run in a loop, periodically loading the
+# latest model checkpoint file and computing evaluation metrics.
+bazel-bin/skip_thoughts/track_perplexity \
+  --input_file_pattern="${DATA_DIR}/validation-?????-of-00001" \
+  --checkpoint_dir="${MODEL_DIR}/train" \
+  --eval_dir="${MODEL_DIR}/val" \
+  --num_eval_examples=50000
+```
+
+If you started the `track_perplexity` script, run a
+[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
+server in a separate process for real-time monitoring of training summaries and
+validation perplexity.
+
+```shell
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Run a TensorBoard server.
+tensorboard --logdir="${MODEL_DIR}"
+```
+
+## Expanding the Vocabulary
+
+### Overview
+
+The vocabulary generated by the preprocessing script contains only 20,000 words
+which is insufficient for many tasks. For example, a sentence from Wikipedia
+might contain nouns that do not appear in this vocabulary.
+
+A solution to this problem described in the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper is to learn a mapping that transfers word representations from one model to
+another. This idea is based on the "Translation Matrix" method from the paper
+[Exploiting Similarities Among Languages for Machine Translation](https://arxiv.org/abs/1309.4168).
+
+
+Specifically, we will load the word embeddings from a trained *Skip-Thoughts*
+model and from a trained [word2vec model](https://arxiv.org/pdf/1301.3781.pdf)
+(which has a much larger vocabulary). We will train a linear regression model
+without regularization to learn a linear mapping from the word2vec embedding
+space to the *Skip-Thoughts* embedding space. We will then apply the linear
+model to all words in the word2vec vocabulary, yielding vectors in the *Skip-
+Thoughts* word embedding space for the union of the two vocabularies.
+
+The linear regression task is to learn a parameter matrix *W* to minimize
+*|| X - Y \* W ||<sup>2</sup>*, where *X* is a matrix of *Skip-Thoughts*
+embeddings of shape `[num_words, dim1]`, *Y* is a matrix of word2vec embeddings
+of shape `[num_words, dim2]`, and *W* is a matrix of shape `[dim2, dim1]`.
+
+### Preparation
+
+First you will need to download and unpack a pretrained
+[word2vec model](https://arxiv.org/pdf/1301.3781.pdf) from
+[this website](https://code.google.com/archive/p/word2vec/)
+([direct download link](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing)).
+This model was trained on the Google News dataset (about 100 billion words).
+
+
+Also ensure that you have already [installed gensim](https://radimrehurek.com/gensim/install.html).
+
+### Run the Vocabulary Expansion Script
+
+```shell
+# Path to checkpoint file or a directory containing checkpoint files (the script
+# will select the most recent).
+CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
+
+# Vocabulary file generated by the preprocessing script.
+SKIP_THOUGHTS_VOCAB="${HOME}/skip_thoughts/data/vocab.txt"
+
+# Path to downloaded word2vec model.
+WORD2VEC_MODEL="${HOME}/skip_thoughts/googlenews/GoogleNews-vectors-negative300.bin"
+
+# Output directory.
+EXP_VOCAB_DIR="${HOME}/skip_thoughts/exp_vocab"
+
+# Build the vocabulary expansion script.
+bazel build -c opt skip_thoughts/vocabulary_expansion
+
+# Run the vocabulary expansion script.
+bazel-bin/skip_thoughts/vocabulary_expansion \
+  --skip_thoughts_model=${CHECKPOINT_PATH} \
+  --skip_thoughts_vocab=${SKIP_THOUGHTS_VOCAB} \
+  --word2vec_model=${WORD2VEC_MODEL} \
+  --output_dir=${EXP_VOCAB_DIR}
+```
+
+## Evaluating a Model
+
+### Overview
+
+The model can be evaluated using the benchmark tasks described in the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper. The following tasks are suported (refer to the paper for full details):
+
+ * **SICK** semantic relatedness task.
+ * **MSRP** (Microsoft Research Paraphrase Corpus) paraphrase detection task.
+ * Binary classification tasks:
+   * **MR** movie review sentiment task.
+   * **CR** customer product review task.
+   * **SUBJ** subjectivity/objectivity task.
+   * **MPQA** opinion polarity task.
+   * **TREC** question-type classification task.
+
+### Preparation
+
+You will need to clone or download the
+[skip-thoughts GitHub repository](https://github.com/ryankiros/skip-thoughts) by
+[ryankiros](https://github.com/ryankiros) (the first author of the Skip-Thoughts
+paper):
+
+```shell
+# Folder to clone the repository to.
+ST_KIROS_DIR="${HOME}/skip_thoughts/skipthoughts_kiros"
+
+# Clone the repository.
+git clone git@github.com:ryankiros/skip-thoughts.git "${ST_KIROS_DIR}/skipthoughts"
+
+# Make the package importable.
+export PYTHONPATH="${ST_KIROS_DIR}/:${PYTHONPATH}"
+```
+
+You will also need to download the data needed for each evaluation task. See the
+instructions [here](https://github.com/ryankiros/skip-thoughts).
+
+For example, the CR (customer review) dataset is found [here](http://nlp.stanford.edu/~sidaw/home/projects:nbsvm). For this task we want the
+files `custrev.pos` and `custrev.neg`.
+
+### Run the Evaluation Tasks
+
+In the following example we will evaluate a unidirectional model ("uni-skip" in
+the paper) on the CR task. To use a bidirectional model ("bi-skip" in the
+paper),  simply pass the flags `--bi_vocab_file`, `--bi_embeddings_file` and
+`--bi_checkpoint_path` instead. To use the "combine-skip" model described in the
+paper you will need to pass both the unidirectional and bidirectional flags.
+
+```shell
+# Path to checkpoint file or a directory containing checkpoint files (the script
+# will select the most recent).
+CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
+
+# Vocabulary file generated by the vocabulary expansion script.
+VOCAB_FILE="${HOME}/skip_thoughts/exp_vocab/vocab.txt"
+
+# Embeddings file generated by the vocabulary expansion script.
+EMBEDDINGS_FILE="${HOME}/skip_thoughts/exp_vocab/embeddings.npy"
+
+# Directory containing files custrev.pos and custrev.neg.
+EVAL_DATA_DIR="${HOME}/skip_thoughts/eval_data"
+
+# Build the evaluation script.
+bazel build -c opt skip_thoughts/evaluate
+
+# Run the evaluation script.
+bazel-bin/skip_thoughts/evaluate \
+  --eval_task=CR \
+  --data_dir=${EVAL_DATA_DIR} \
+  --uni_vocab_file=${VOCAB_FILE} \
+  --uni_embeddings_file=${EMBEDDINGS_FILE} \
+  --uni_checkpoint_path=${CHECKPOINT_PATH}
+```
+
+Output:
+
+```python
+[0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
+ 0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
+ 0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
+ 0.82758620689655171]
+```
+
+The output is a list of accuracies of 10 cross-validation classification models.
+To get a single number, simply take the average:
+
+```python
+ipython  # Launch iPython.
+
+In [0]:
+import numpy as np
+np.mean([0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
+         0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
+         0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
+         0.82758620689655171])
+
+Out [0]: 0.84009936423729525
+```
+
+## Encoding Sentences
+
+In this example we will encode data from the
+[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/)
+(specifically the [sentence polarity dataset v1.0](https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz)).
+
+```python
+ipython  # Launch iPython.
+
+In [0]:
+
+# Imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import os.path
+import scipy.spatial.distance as sd
+from skip_thoughts import configuration
+from skip_thoughts import combined_encoder
+
+In [1]:
+# Set paths to the model.
+VOCAB_FILE = "/path/to/vocab.txt"
+EMBEDDING_MATRIX_FILE = "/path/to/embeddings.npy"
+CHECKPOINT_PATH = "/path/to/model.ckpt-9999"
+# The following directory should contain files rt-polarity.neg and
+# rt-polarity.pos.
+MR_DATA_DIR = "/dir/containing/mr/data"
+
+In [2]:
+# Set up the encoder. Here we are using a single unidirectional model.
+# To use a bidirectional model as well, call load_encoder() again with
+# configuration.ModelConfig(bidirectional_encoder=True) and paths to the
+# bidirectional model's files. The encoder will use the concatenation of
+# all loaded models.
+encoder = combined_encoder.CombinedEncoder()
+encoder.load_encoder(configuration.ModelConfig(),
+                     vocabulary_file=VOCAB_FILE,
+                     embedding_matrix_file=EMBEDDING_MATRIX_FILE,
+                     checkpoint_path=CHECKPOINT_PATH)
+
+In [3]:
+# Load the movie review dataset.
+data = []
+with open(os.path.join(MR_DATA_DIR, 'rt-polarity.neg'), 'rb') as f:
+  data.extend([line.decode('latin-1').strip() for line in f])
+with open(os.path.join(MR_DATA_DIR, 'rt-polarity.pos'), 'rb') as f:
+  data.extend([line.decode('latin-1').strip() for line in f])
+
+In [4]:
+# Generate Skip-Thought Vectors for each sentence in the dataset.
+encodings = encoder.encode(data)
+
+In [5]:
+# Define a helper function to generate nearest neighbors.
+def get_nn(ind, num=10):
+  encoding = encodings[ind]
+  scores = sd.cdist([encoding], encodings, "cosine")[0]
+  sorted_ids = np.argsort(scores)
+  print("Sentence:")
+  print("", data[ind])
+  print("\nNearest neighbors:")
+  for i in range(1, num + 1):
+    print(" %d. %s (%.3f)" %
+          (i, data[sorted_ids[i]], scores[sorted_ids[i]]))
+
+In [6]:
+# Compute nearest neighbors of the first sentence in the dataset.
+get_nn(0)
+```
+
+Output:
+
+```
+Sentence:
+ simplistic , silly and tedious .
+
+Nearest neighbors:
+ 1. trite , banal , cliched , mostly inoffensive . (0.247)
+ 2. banal and predictable . (0.253)
+ 3. witless , pointless , tasteless and idiotic . (0.272)
+ 4. loud , silly , stupid and pointless . (0.295)
+ 5. grating and tedious . (0.299)
+ 6. idiotic and ugly . (0.330)
+ 7. black-and-white and unrealistic . (0.335)
+ 8. hopelessly inane , humorless and under-inspired . (0.335)
+ 9. shallow , noisy and pretentious . (0.340)
+ 10. . . . unlikable , uninteresting , unfunny , and completely , utterly inept . (0.346)
+```

+ 0 - 0
skip_thoughts/WORKSPACE


+ 94 - 0
skip_thoughts/skip_thoughts/BUILD

@@ -0,0 +1,94 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//skip_thoughts/...",
+    ],
+)
+
+py_library(
+    name = "configuration",
+    srcs = ["configuration.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "skip_thoughts_model",
+    srcs = ["skip_thoughts_model.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//skip_thoughts/ops:gru_cell",
+        "//skip_thoughts/ops:input_ops",
+    ],
+)
+
+py_test(
+    name = "skip_thoughts_model_test",
+    size = "large",
+    srcs = ["skip_thoughts_model_test.py"],
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "train",
+    srcs = ["train.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "track_perplexity",
+    srcs = ["track_perplexity.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "vocabulary_expansion",
+    srcs = ["vocabulary_expansion.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "skip_thoughts_encoder",
+    srcs = ["skip_thoughts_encoder.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":skip_thoughts_model",
+        "//skip_thoughts/data:special_words",
+    ],
+)
+
+py_library(
+    name = "encoder_manager",
+    srcs = ["encoder_manager.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":skip_thoughts_encoder",
+    ],
+)
+
+py_binary(
+    name = "evaluate",
+    srcs = ["evaluate.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":encoder_manager",
+        "//skip_thoughts:configuration",
+    ],
+)
+

+ 0 - 0
skip_thoughts/skip_thoughts/__init__.py


+ 110 - 0
skip_thoughts/skip_thoughts/configuration.py

@@ -0,0 +1,110 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Default configuration for model architecture and training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class _HParams(object):
+  """Wrapper for configuration parameters."""
+  pass
+
+
+def model_config(input_file_pattern=None,
+                 input_queue_capacity=640000,
+                 num_input_reader_threads=1,
+                 shuffle_input_data=True,
+                 uniform_init_scale=0.1,
+                 vocab_size=20000,
+                 batch_size=128,
+                 word_embedding_dim=620,
+                 bidirectional_encoder=False,
+                 encoder_dim=2400):
+  """Creates a model configuration object.
+
+  Args:
+    input_file_pattern: File pattern of sharded TFRecord files containing
+      tf.Example protobufs.
+    input_queue_capacity: Number of examples to keep in the input queue.
+    num_input_reader_threads: Number of threads for prefetching input
+      tf.Examples.
+    shuffle_input_data: Whether to shuffle the input data.
+    uniform_init_scale: Scale of random uniform initializer.
+    vocab_size: Number of unique words in the vocab.
+    batch_size: Batch size (training and evaluation only).
+    word_embedding_dim: Word embedding dimension.
+    bidirectional_encoder: Whether to use a bidirectional or unidirectional
+      encoder RNN.
+    encoder_dim: Number of output dimensions of the sentence encoder.
+
+  Returns:
+    An object containing model configuration parameters.
+  """
+  config = _HParams()
+  config.input_file_pattern = input_file_pattern
+  config.input_queue_capacity = input_queue_capacity
+  config.num_input_reader_threads = num_input_reader_threads
+  config.shuffle_input_data = shuffle_input_data
+  config.uniform_init_scale = uniform_init_scale
+  config.vocab_size = vocab_size
+  config.batch_size = batch_size
+  config.word_embedding_dim = word_embedding_dim
+  config.bidirectional_encoder = bidirectional_encoder
+  config.encoder_dim = encoder_dim
+  return config
+
+
+def training_config(learning_rate=0.0008,
+                    learning_rate_decay_factor=0.5,
+                    learning_rate_decay_steps=400000,
+                    number_of_steps=500000,
+                    clip_gradient_norm=5.0,
+                    save_model_secs=600,
+                    save_summaries_secs=600):
+  """Creates a training configuration object.
+
+  Args:
+    learning_rate: Initial learning rate.
+    learning_rate_decay_factor: If > 0, the learning rate decay factor.
+    learning_rate_decay_steps: The number of steps before the learning rate
+      decays by learning_rate_decay_factor.
+    number_of_steps: The total number of training steps to run. Passing None
+      will cause the training script to run indefinitely.
+    clip_gradient_norm: If not None, then clip gradients to this value.
+    save_model_secs: How often (in seconds) to save model checkpoints.
+    save_summaries_secs: How often (in seconds) to save model summaries.
+
+  Returns:
+    An object containing training configuration parameters.
+
+  Raises:
+    ValueError: If learning_rate_decay_factor is set and
+      learning_rate_decay_steps is unset.
+  """
+  if learning_rate_decay_factor and not learning_rate_decay_steps:
+    raise ValueError(
+        "learning_rate_decay_factor requires learning_rate_decay_steps.")
+
+  config = _HParams()
+  config.learning_rate = learning_rate
+  config.learning_rate_decay_factor = learning_rate_decay_factor
+  config.learning_rate_decay_steps = learning_rate_decay_steps
+  config.number_of_steps = number_of_steps
+  config.clip_gradient_norm = clip_gradient_norm
+  config.save_model_secs = save_model_secs
+  config.save_summaries_secs = save_summaries_secs
+  return config

+ 23 - 0
skip_thoughts/skip_thoughts/data/BUILD

@@ -0,0 +1,23 @@
+package(default_visibility = ["//skip_thoughts:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "special_words",
+    srcs = ["special_words.py"],
+    srcs_version = "PY2AND3",
+    deps = [],
+)
+
+py_binary(
+    name = "preprocess_dataset",
+    srcs = [
+        "preprocess_dataset.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":special_words",
+    ],
+)

+ 0 - 0
skip_thoughts/skip_thoughts/data/__init__.py


+ 301 - 0
skip_thoughts/skip_thoughts/data/preprocess_dataset.py

@@ -0,0 +1,301 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts a set of text files to TFRecord format with Example protos.
+
+Each Example proto in the output contains the following fields:
+
+  decode_pre: list of int64 ids corresponding to the "previous" sentence.
+  encode: list of int64 ids corresponding to the "current" sentence.
+  decode_post: list of int64 ids corresponding to the "post" sentence.
+
+In addition, the following files are generated:
+
+  vocab.txt: List of "<word> <id>" pairs, where <id> is the integer
+             encoding of <word> in the Example protos.
+  word_counts.txt: List of "<word> <count>" pairs, where <count> is the number
+                   of occurrences of <word> in the input files.
+
+The vocabulary of word ids is constructed from the top --num_words by word
+count. All other words get the <unk> word id.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts.data import special_words
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_files", None,
+                       "Comma-separated list of globs matching the input "
+                       "files. The format of the input files is assumed to be "
+                       "a list of newline-separated sentences, where each "
+                       "sentence is already tokenized.")
+
+tf.flags.DEFINE_string("vocab_file", "",
+                       "(Optional) existing vocab file. Otherwise, a new vocab "
+                       "file is created and written to the output directory. "
+                       "The file format is a list of newline-separated words, "
+                       "where the word id is the corresponding 0-based index "
+                       "in the file.")
+
+tf.flags.DEFINE_string("output_dir", None, "Output directory.")
+
+tf.flags.DEFINE_integer("train_output_shards", 100,
+                        "Number of output shards for the training set.")
+
+tf.flags.DEFINE_integer("validation_output_shards", 1,
+                        "Number of output shards for the validation set.")
+
+tf.flags.DEFINE_integer("num_validation_sentences", 50000,
+                        "Number of output shards for the validation set.")
+
+tf.flags.DEFINE_integer("num_words", 20000,
+                        "Number of words to include in the output.")
+
+tf.flags.DEFINE_integer("max_sentences", 0,
+                        "If > 0, the maximum number of sentences to output.")
+
+tf.flags.DEFINE_integer("max_sentence_length", 30,
+                        "If > 0, exclude sentences whose encode, decode_pre OR"
+                        "decode_post sentence exceeds this length.")
+
+tf.flags.DEFINE_boolean("add_eos", True,
+                        "Whether to add end-of-sentence ids to the output.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _build_vocabulary(input_files):
+  """Loads or builds the model vocabulary.
+
+  Args:
+    input_files: List of pre-tokenized input .txt files.
+
+  Returns:
+    vocab: A dictionary of word to id.
+  """
+  if FLAGS.vocab_file:
+    tf.logging.info("Loading existing vocab file.")
+    vocab = collections.OrderedDict()
+    with tf.gfile.GFile(FLAGS.vocab_file, mode="r") as f:
+      for i, line in enumerate(f):
+        word = line.decode("utf-8").strip()
+        assert word not in vocab, "Attempting to add word twice: %s" % word
+        vocab[word] = i
+    tf.logging.info("Read vocab of size %d from %s",
+                    len(vocab), FLAGS.vocab_file)
+    return vocab
+
+  tf.logging.info("Creating vocabulary.")
+  num = 0
+  wordcount = collections.Counter()
+  for input_file in input_files:
+    tf.logging.info("Processing file: %s", input_file)
+    for sentence in tf.gfile.FastGFile(input_file):
+      wordcount.update(sentence.split())
+
+      num += 1
+      if num % 1000000 == 0:
+        tf.logging.info("Processed %d sentences", num)
+
+  tf.logging.info("Processed %d sentences total", num)
+
+  words = wordcount.keys()
+  freqs = wordcount.values()
+  sorted_indices = np.argsort(freqs)[::-1]
+
+  vocab = collections.OrderedDict()
+  vocab[special_words.EOS] = special_words.EOS_ID
+  vocab[special_words.UNK] = special_words.UNK_ID
+  for w_id, w_index in enumerate(sorted_indices[0:FLAGS.num_words - 2]):
+    vocab[words[w_index]] = w_id + 2  # 0: EOS, 1: UNK.
+
+  tf.logging.info("Created vocab with %d words", len(vocab))
+
+  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
+  with tf.gfile.FastGFile(vocab_file, "w") as f:
+    f.write("\n".join(vocab.keys()))
+  tf.logging.info("Wrote vocab file to %s", vocab_file)
+
+  word_counts_file = os.path.join(FLAGS.output_dir, "word_counts.txt")
+  with tf.gfile.FastGFile(word_counts_file, "w") as f:
+    for i in sorted_indices:
+      f.write("%s %d\n" % (words[i], freqs[i]))
+  tf.logging.info("Wrote word counts file to %s", word_counts_file)
+
+  return vocab
+
+
+def _int64_feature(value):
+  """Helper for creating an Int64 Feature."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(
+      value=[int(v) for v in value]))
+
+
+def _sentence_to_ids(sentence, vocab):
+  """Helper for converting a sentence (list of words) to a list of ids."""
+  ids = [vocab.get(w, special_words.UNK_ID) for w in sentence]
+  if FLAGS.add_eos:
+    ids.append(special_words.EOS_ID)
+  return ids
+
+
+def _create_serialized_example(predecessor, current, successor, vocab):
+  """Helper for creating a serialized Example proto."""
+  example = tf.train.Example(features=tf.train.Features(feature={
+      "decode_pre": _int64_feature(_sentence_to_ids(predecessor, vocab)),
+      "encode": _int64_feature(_sentence_to_ids(current, vocab)),
+      "decode_post": _int64_feature(_sentence_to_ids(successor, vocab)),
+  }))
+
+  return example.SerializeToString()
+
+
+def _process_input_file(filename, vocab, stats):
+  """Processes the sentences in an input file.
+
+  Args:
+    filename: Path to a pre-tokenized input .txt file.
+    vocab: A dictionary of word to id.
+    stats: A Counter object for statistics.
+
+  Returns:
+    processed: A list of serialized Example protos
+  """
+  tf.logging.info("Processing input file: %s", filename)
+  processed = []
+
+  predecessor = None  # Predecessor sentence (list of words).
+  current = None  # Current sentence (list of words).
+  successor = None  # Successor sentence (list of words).
+
+  for successor_str in tf.gfile.FastGFile(filename):
+    stats.update(["sentences_seen"])
+    successor = successor_str.split()
+
+    # The first 2 sentences per file will be skipped.
+    if predecessor and current and successor:
+      stats.update(["sentences_considered"])
+
+      # Note that we are going to insert <EOS> later, so we only allow
+      # sentences with strictly less than max_sentence_length to pass.
+      if FLAGS.max_sentence_length and (
+          len(predecessor) >= FLAGS.max_sentence_length or len(current) >=
+          FLAGS.max_sentence_length or len(successor) >=
+          FLAGS.max_sentence_length):
+        stats.update(["sentences_too_long"])
+      else:
+        serialized = _create_serialized_example(predecessor, current, successor,
+                                                vocab)
+        processed.append(serialized)
+        stats.update(["sentences_output"])
+
+    predecessor = current
+    current = successor
+
+    sentences_seen = stats["sentences_seen"]
+    sentences_output = stats["sentences_output"]
+    if sentences_seen and sentences_seen % 100000 == 0:
+      tf.logging.info("Processed %d sentences (%d output)", sentences_seen,
+                      sentences_output)
+    if FLAGS.max_sentences and sentences_output >= FLAGS.max_sentences:
+      break
+
+  tf.logging.info("Completed processing file %s", filename)
+  return processed
+
+
+def _write_shard(filename, dataset, indices):
+  """Writes a TFRecord shard."""
+  with tf.python_io.TFRecordWriter(filename) as writer:
+    for j in indices:
+      writer.write(dataset[j])
+
+
+def _write_dataset(name, dataset, indices, num_shards):
+  """Writes a sharded TFRecord dataset.
+
+  Args:
+    name: Name of the dataset (e.g. "train").
+    dataset: List of serialized Example protos.
+    indices: List of indices of 'dataset' to be written.
+    num_shards: The number of output shards.
+  """
+  tf.logging.info("Writing dataset %s", name)
+  borders = np.int32(np.linspace(0, len(indices), num_shards + 1))
+  for i in range(num_shards):
+    filename = os.path.join(FLAGS.output_dir, "%s-%.5d-of-%.5d" % (name, i,
+                                                                   num_shards))
+    shard_indices = indices[borders[i]:borders[i + 1]]
+    _write_shard(filename, dataset, shard_indices)
+    tf.logging.info("Wrote dataset indices [%d, %d) to output shard %s",
+                    borders[i], borders[i + 1], filename)
+  tf.logging.info("Finished writing %d sentences in dataset %s.",
+                  len(indices), name)
+
+
+def main(unused_argv):
+  if not FLAGS.input_files:
+    raise ValueError("--input_files is required.")
+  if not FLAGS.output_dir:
+    raise ValueError("--output_dir is required.")
+
+  if not tf.gfile.IsDirectory(FLAGS.output_dir):
+    tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  input_files = []
+  for pattern in FLAGS.input_files.split(","):
+    match = tf.gfile.Glob(FLAGS.input_files)
+    if not match:
+      raise ValueError("Found no files matching %s" % pattern)
+    input_files.extend(match)
+  tf.logging.info("Found %d input files.", len(input_files))
+
+  vocab = _build_vocabulary(input_files)
+
+  tf.logging.info("Generating dataset.")
+  stats = collections.Counter()
+  dataset = []
+  for filename in input_files:
+    dataset.extend(_process_input_file(filename, vocab, stats))
+    if FLAGS.max_sentences and stats["sentences_output"] >= FLAGS.max_sentences:
+      break
+
+  tf.logging.info("Generated dataset with %d sentences.", len(dataset))
+  for k, v in stats.items():
+    tf.logging.info("%s: %d", k, v)
+
+  tf.logging.info("Shuffling dataset.")
+  np.random.seed(123)
+  shuffled_indices = np.random.permutation(len(dataset))
+  val_indices = shuffled_indices[:FLAGS.num_validation_sentences]
+  train_indices = shuffled_indices[FLAGS.num_validation_sentences:]
+
+  _write_dataset("train", dataset, train_indices, FLAGS.train_output_shards)
+  _write_dataset("validation", dataset, val_indices,
+                 FLAGS.validation_output_shards)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 27 - 0
skip_thoughts/skip_thoughts/data/special_words.py

@@ -0,0 +1,27 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special word constants.
+
+NOTE: The ids of the EOS and UNK constants should not be modified. It is assumed
+that these always occupy the first two ids.
+"""
+
+# End of sentence.
+EOS = "<eos>"
+EOS_ID = 0
+
+# Unknown.
+UNK = "<unk>"
+UNK_ID = 1

+ 134 - 0
skip_thoughts/skip_thoughts/encoder_manager.py

@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Manager class for loading and encoding with multiple skip-thoughts models.
+
+If multiple models are loaded at once then the encode() function returns the
+concatenation of the outputs of each model.
+
+Example usage:
+  manager = EncoderManager()
+  manager.load_model(model_config_1, vocabulary_file_1, embedding_matrix_file_1,
+                     checkpoint_path_1)
+  manager.load_model(model_config_2, vocabulary_file_2, embedding_matrix_file_2,
+                     checkpoint_path_2)
+  encodings = manager.encode(data)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import skip_thoughts_encoder
+
+
+class EncoderManager(object):
+  """Manager class for loading and encoding with skip-thoughts models."""
+
+  def __init__(self):
+    self.encoders = []
+    self.sessions = []
+
+  def load_model(self, model_config, vocabulary_file, embedding_matrix_file,
+                 checkpoint_path):
+    """Loads a skip-thoughts model.
+
+    Args:
+      model_config: Object containing parameters for building the model.
+      vocabulary_file: Path to vocabulary file containing a list of newline-
+        separated words where the word id is the corresponding 0-based index in
+        the file.
+      embedding_matrix_file: Path to a serialized numpy array of shape
+        [vocab_size, embedding_dim].
+      checkpoint_path: SkipThoughtsModel checkpoint file or a directory
+        containing a checkpoint file.
+    """
+    tf.logging.info("Reading vocabulary from %s", vocabulary_file)
+    with tf.gfile.GFile(vocabulary_file, mode="r") as f:
+      lines = list(f.readlines())
+    reverse_vocab = [line.decode("utf-8").strip() for line in lines]
+    tf.logging.info("Loaded vocabulary with %d words.", len(reverse_vocab))
+
+    tf.logging.info("Loading embedding matrix from %s", embedding_matrix_file)
+    # Note: tf.gfile.GFile doesn't work here because np.load() calls f.seek()
+    # with 3 arguments.
+    with open(embedding_matrix_file, "r") as f:
+      embedding_matrix = np.load(f)
+    tf.logging.info("Loaded embedding matrix with shape %s",
+                    embedding_matrix.shape)
+
+    word_embeddings = collections.OrderedDict(
+        zip(reverse_vocab, embedding_matrix))
+
+    g = tf.Graph()
+    with g.as_default():
+      encoder = skip_thoughts_encoder.SkipThoughtsEncoder(word_embeddings)
+      restore_model = encoder.build_graph_from_config(model_config,
+                                                      checkpoint_path)
+
+    sess = tf.Session(graph=g)
+    restore_model(sess)
+
+    self.encoders.append(encoder)
+    self.sessions.append(sess)
+
+  def encode(self,
+             data,
+             use_norm=True,
+             verbose=False,
+             batch_size=128,
+             use_eos=False):
+    """Encodes a sequence of sentences as skip-thought vectors.
+
+    Args:
+      data: A list of input strings.
+      use_norm: If True, normalize output skip-thought vectors to unit L2 norm.
+      verbose: Whether to log every batch.
+      batch_size: Batch size for the RNN encoders.
+      use_eos: If True, append the end-of-sentence word to each input sentence.
+
+    Returns:
+      thought_vectors: A list of numpy arrays corresponding to 'data'.
+
+    Raises:
+      ValueError: If called before calling load_encoder.
+    """
+    if not self.encoders:
+      raise ValueError(
+          "Must call load_model at least once before calling encode.")
+
+    encoded = []
+    for encoder, sess in zip(self.encoders, self.sessions):
+      encoded.append(
+          np.array(
+              encoder.encode(
+                  sess,
+                  data,
+                  use_norm=use_norm,
+                  verbose=verbose,
+                  batch_size=batch_size,
+                  use_eos=use_eos)))
+
+    return np.concatenate(encoded, axis=1)
+
+  def close(self):
+    """Closes the active TensorFlow Sessions."""
+    for sess in self.sessions:
+      sess.close()

+ 117 - 0
skip_thoughts/skip_thoughts/evaluate.py

@@ -0,0 +1,117 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate a skip-thoughts model.
+
+This script can evaluate a model with a unidirectional encoder ("uni-skip" in
+the paper); or a model with a bidirectional encoder ("bi-skip"); or the
+combination of a model with a unidirectional encoder and a model with a
+bidirectional encoder ("combine-skip").
+
+The uni-skip model (if it exists) is specified by the flags
+--uni_vocab_file, --uni_embeddings_file, --uni_checkpoint_path.
+
+The bi-skip model (if it exists) is specified by the flags
+--bi_vocab_file, --bi_embeddings_path, --bi_checkpoint_path.
+
+The evaluation tasks have different running times. SICK may take 5-10 minutes.
+MSRP, TREC and CR may take 20-60 minutes. SUBJ, MPQA and MR may take 2+ hours.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from skipthoughts import eval_classification
+from skipthoughts import eval_msrp
+from skipthoughts import eval_sick
+from skipthoughts import eval_trec
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import encoder_manager
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("eval_task", "CR",
+                       "Name of the evaluation task to run. Available tasks: "
+                       "MR, CR, SUBJ, MPQA, SICK, MSRP, TREC.")
+
+tf.flags.DEFINE_string("data_dir", None, "Directory containing training data.")
+
+tf.flags.DEFINE_string("uni_vocab_file", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+tf.flags.DEFINE_string("bi_vocab_file", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+
+tf.flags.DEFINE_string("uni_embeddings_file", None,
+                       "Path to serialized numpy array of shape "
+                       "[vocab_size, embedding_dim].")
+tf.flags.DEFINE_string("bi_embeddings_file", None,
+                       "Path to serialized numpy array of shape "
+                       "[vocab_size, embedding_dim].")
+
+tf.flags.DEFINE_string("uni_checkpoint_path", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+tf.flags.DEFINE_string("bi_checkpoint_path", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def main(unused_argv):
+  if not FLAGS.data_dir:
+    raise ValueError("--data_dir is required.")
+
+  encoder = encoder_manager.EncoderManager()
+
+  # Maybe load unidirectional encoder.
+  if FLAGS.uni_checkpoint_path:
+    print("Loading unidirectional model...")
+    uni_config = configuration.model_config()
+    encoder.load_model(uni_config, FLAGS.uni_vocab_file,
+                       FLAGS.uni_embeddings_file, FLAGS.uni_checkpoint_path)
+
+  # Maybe load bidirectional encoder.
+  if FLAGS.bi_checkpoint_path:
+    print("Loading bidirectional model...")
+    bi_config = configuration.model_config(bidirectional_encoder=True)
+    encoder.load_model(bi_config, FLAGS.bi_vocab_file, FLAGS.bi_embeddings_file,
+                       FLAGS.bi_checkpoint_path)
+
+  if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
+    eval_classification.eval_nested_kfold(
+        encoder, FLAGS.eval_task, FLAGS.data_dir, use_nb=False)
+  elif FLAGS.eval_task == "SICK":
+    eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
+  elif FLAGS.eval_task == "MSRP":
+    eval_msrp.evaluate(
+        encoder, evalcv=True, evaltest=True, use_feats=True, loc=FLAGS.data_dir)
+  elif FLAGS.eval_task == "TREC":
+    eval_trec.evaluate(encoder, evalcv=True, evaltest=True, loc=FLAGS.data_dir)
+  else:
+    raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)
+
+  encoder.close()
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 17 - 0
skip_thoughts/skip_thoughts/ops/BUILD

@@ -0,0 +1,17 @@
+package(default_visibility = ["//skip_thoughts:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "input_ops",
+    srcs = ["input_ops.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "gru_cell",
+    srcs = ["gru_cell.py"],
+    srcs_version = "PY2AND3",
+)

+ 0 - 0
skip_thoughts/skip_thoughts/ops/__init__.py


+ 134 - 0
skip_thoughts/skip_thoughts/ops/gru_cell.py

@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GRU cell implementation for the skip-thought vectors model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+_layer_norm = tf.contrib.layers.layer_norm
+
+
+class LayerNormGRUCell(tf.contrib.rnn.RNNCell):
+  """GRU cell with layer normalization.
+
+  The layer normalization implementation is based on:
+
+    https://arxiv.org/abs/1607.06450.
+
+  "Layer Normalization"
+  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+  """
+
+  def __init__(self,
+               num_units,
+               w_initializer,
+               u_initializer,
+               b_initializer,
+               activation=tf.nn.tanh):
+    """Initializes the cell.
+
+    Args:
+      num_units: Number of cell units.
+      w_initializer: Initializer for the "W" (input) parameter matrices.
+      u_initializer: Initializer for the "U" (recurrent) parameter matrices.
+      b_initializer: Initializer for the "b" (bias) parameter vectors.
+      activation: Cell activation function.
+    """
+    self._num_units = num_units
+    self._w_initializer = w_initializer
+    self._u_initializer = u_initializer
+    self._b_initializer = b_initializer
+    self._activation = activation
+
+  @property
+  def state_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def _w_h_initializer(self):
+    """Returns an initializer for the "W_h" parameter matrix.
+
+    See equation (23) in the paper. The "W_h" parameter matrix is the
+    concatenation of two parameter submatrices. The matrix returned is
+    [U_z, U_r].
+
+    Returns:
+      A Tensor with shape [num_units, 2 * num_units] as described above.
+    """
+
+    def _initializer(shape, dtype=tf.float32, partition_info=None):
+      num_units = self._num_units
+      assert shape == [num_units, 2 * num_units]
+      u_z = self._u_initializer([num_units, num_units], dtype, partition_info)
+      u_r = self._u_initializer([num_units, num_units], dtype, partition_info)
+      return tf.concat([u_z, u_r], 1)
+
+    return _initializer
+
+  def _w_x_initializer(self, input_dim):
+    """Returns an initializer for the "W_x" parameter matrix.
+
+    See equation (23) in the paper. The "W_x" parameter matrix is the
+    concatenation of two parameter submatrices. The matrix returned is
+    [W_z, W_r].
+
+    Args:
+      input_dim: The dimension of the cell inputs.
+
+    Returns:
+      A Tensor with shape [input_dim, 2 * num_units] as described above.
+    """
+
+    def _initializer(shape, dtype=tf.float32, partition_info=None):
+      num_units = self._num_units
+      assert shape == [input_dim, 2 * num_units]
+      w_z = self._w_initializer([input_dim, num_units], dtype, partition_info)
+      w_r = self._w_initializer([input_dim, num_units], dtype, partition_info)
+      return tf.concat([w_z, w_r], 1)
+
+    return _initializer
+
+  def __call__(self, inputs, state, scope=None):
+    """GRU cell with layer normalization."""
+    input_dim = inputs.get_shape().as_list()[1]
+    num_units = self._num_units
+
+    with tf.variable_scope(scope or "gru_cell"):
+      with tf.variable_scope("gates"):
+        w_h = tf.get_variable(
+            "w_h", [num_units, 2 * num_units],
+            initializer=self._w_h_initializer())
+        w_x = tf.get_variable(
+            "w_x", [input_dim, 2 * num_units],
+            initializer=self._w_x_initializer(input_dim))
+        z_and_r = (_layer_norm(tf.matmul(state, w_h), scope="layer_norm/w_h") +
+                   _layer_norm(tf.matmul(inputs, w_x), scope="layer_norm/w_x"))
+        z, r = tf.split(tf.sigmoid(z_and_r), 2, 1)
+      with tf.variable_scope("candidate"):
+        w = tf.get_variable(
+            "w", [input_dim, num_units], initializer=self._w_initializer)
+        u = tf.get_variable(
+            "u", [num_units, num_units], initializer=self._u_initializer)
+        h_hat = (r * _layer_norm(tf.matmul(state, u), scope="layer_norm/u") +
+                 _layer_norm(tf.matmul(inputs, w), scope="layer_norm/w"))
+      new_h = (1 - z) * state + z * self._activation(h_hat)
+    return new_h, new_h

+ 118 - 0
skip_thoughts/skip_thoughts/ops/input_ops.py

@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+import tensorflow as tf
+
+# A SentenceBatch is a pair of Tensors:
+#  ids: Batch of input sentences represented as sequences of word ids: an int64
+#    Tensor with shape [batch_size, padded_length].
+#  mask: Boolean mask distinguishing real words (1) from padded words (0): an
+#    int32 Tensor with shape [batch_size, padded_length].
+SentenceBatch = collections.namedtuple("SentenceBatch", ("ids", "mask"))
+
+
+def parse_example_batch(serialized):
+  """Parses a batch of tf.Example protos.
+
+  Args:
+    serialized: A 1-D string Tensor; a batch of serialized tf.Example protos.
+  Returns:
+    encode: A SentenceBatch of encode sentences.
+    decode_pre: A SentenceBatch of "previous" sentences to decode.
+    decode_post: A SentenceBatch of "post" sentences to decode.
+  """
+  features = tf.parse_example(
+      serialized,
+      features={
+          "encode": tf.VarLenFeature(dtype=tf.int64),
+          "decode_pre": tf.VarLenFeature(dtype=tf.int64),
+          "decode_post": tf.VarLenFeature(dtype=tf.int64),
+      })
+
+  def _sparse_to_batch(sparse):
+    ids = tf.sparse_tensor_to_dense(sparse)  # Padding with zeroes.
+    mask = tf.sparse_to_dense(sparse.indices, sparse.dense_shape,
+                              tf.ones_like(sparse.values, dtype=tf.int32))
+    return SentenceBatch(ids=ids, mask=mask)
+
+  output_names = ("encode", "decode_pre", "decode_post")
+  return tuple(_sparse_to_batch(features[x]) for x in output_names)
+
+
+def prefetch_input_data(reader,
+                        file_pattern,
+                        shuffle,
+                        capacity,
+                        num_reader_threads=1):
+  """Prefetches string values from disk into an input queue.
+
+  Args:
+    reader: Instance of tf.ReaderBase.
+    file_pattern: Comma-separated list of file patterns (e.g.
+        "/tmp/train_data-?????-of-00100", where '?' acts as a wildcard that
+        matches any character).
+    shuffle: Boolean; whether to randomly shuffle the input data.
+    capacity: Queue capacity (number of records).
+    num_reader_threads: Number of reader threads feeding into the queue.
+
+  Returns:
+    A Queue containing prefetched string values.
+  """
+  data_files = []
+  for pattern in file_pattern.split(","):
+    data_files.extend(tf.gfile.Glob(pattern))
+  if not data_files:
+    tf.logging.fatal("Found no input files matching %s", file_pattern)
+  else:
+    tf.logging.info("Prefetching values from %d files matching %s",
+                    len(data_files), file_pattern)
+
+  filename_queue = tf.train.string_input_producer(
+      data_files, shuffle=shuffle, capacity=16, name="filename_queue")
+
+  if shuffle:
+    min_after_dequeue = int(0.6 * capacity)
+    values_queue = tf.RandomShuffleQueue(
+        capacity=capacity,
+        min_after_dequeue=min_after_dequeue,
+        dtypes=[tf.string],
+        shapes=[[]],
+        name="random_input_queue")
+  else:
+    values_queue = tf.FIFOQueue(
+        capacity=capacity,
+        dtypes=[tf.string],
+        shapes=[[]],
+        name="fifo_input_queue")
+
+  enqueue_ops = []
+  for _ in range(num_reader_threads):
+    _, value = reader.read(filename_queue)
+    enqueue_ops.append(values_queue.enqueue([value]))
+  tf.train.queue_runner.add_queue_runner(
+      tf.train.queue_runner.QueueRunner(values_queue, enqueue_ops))
+  tf.summary.scalar("queue/%s/fraction_of_%d_full" % (values_queue.name,
+                                                      capacity),
+                    tf.cast(values_queue.size(), tf.float32) * (1.0 / capacity))
+
+  return values_queue

+ 258 - 0
skip_thoughts/skip_thoughts/skip_thoughts_encoder.py

@@ -0,0 +1,258 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for encoding text using a trained SkipThoughtsModel.
+
+Example usage:
+  g = tf.Graph()
+  with g.as_default():
+    encoder = SkipThoughtsEncoder(embeddings)
+    restore_fn = encoder.build_graph_from_config(model_config, checkpoint_path)
+
+  with tf.Session(graph=g) as sess:
+    restore_fn(sess)
+    skip_thought_vectors = encoder.encode(sess, data)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+
+import nltk
+import nltk.tokenize
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import skip_thoughts_model
+from skip_thoughts.data import special_words
+
+
+def _pad(seq, target_len):
+  """Pads a sequence of word embeddings up to the target length.
+
+  Args:
+    seq: Sequence of word embeddings.
+    target_len: Desired padded sequence length.
+
+  Returns:
+    embeddings: Input sequence padded with zero embeddings up to the target
+      length.
+    mask: A 0/1 vector with zeros corresponding to padded embeddings.
+
+  Raises:
+    ValueError: If len(seq) is not in the interval (0, target_len].
+  """
+  seq_len = len(seq)
+  if seq_len <= 0 or seq_len > target_len:
+    raise ValueError("Expected 0 < len(seq) <= %d, got %d" % (target_len,
+                                                              seq_len))
+
+  emb_dim = seq[0].shape[0]
+  padded_seq = np.zeros(shape=(target_len, emb_dim), dtype=seq[0].dtype)
+  mask = np.zeros(shape=(target_len,), dtype=np.int8)
+  for i in range(seq_len):
+    padded_seq[i] = seq[i]
+    mask[i] = 1
+  return padded_seq, mask
+
+
+def _batch_and_pad(sequences):
+  """Batches and pads sequences of word embeddings into a 2D array.
+
+  Args:
+    sequences: A list of batch_size sequences of word embeddings.
+
+  Returns:
+    embeddings: A numpy array with shape [batch_size, padded_length, emb_dim].
+    mask: A numpy 0/1 array with shape [batch_size, padded_length] with zeros
+      corresponding to padded elements.
+  """
+  batch_embeddings = []
+  batch_mask = []
+  batch_len = max([len(seq) for seq in sequences])
+  for seq in sequences:
+    embeddings, mask = _pad(seq, batch_len)
+    batch_embeddings.append(embeddings)
+    batch_mask.append(mask)
+  return np.array(batch_embeddings), np.array(batch_mask)
+
+
+class SkipThoughtsEncoder(object):
+  """Skip-thoughts sentence encoder."""
+
+  def __init__(self, embeddings):
+    """Initializes the encoder.
+
+    Args:
+      embeddings: Dictionary of word to embedding vector (1D numpy array).
+    """
+    self._sentence_detector = nltk.data.load("tokenizers/punkt/english.pickle")
+    self._embeddings = embeddings
+
+  def _create_restore_fn(self, checkpoint_path, saver):
+    """Creates a function that restores a model from checkpoint.
+
+    Args:
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+      saver: Saver for restoring variables from the checkpoint file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+
+    Raises:
+      ValueError: If checkpoint_path does not refer to a checkpoint file or a
+        directory containing a checkpoint file.
+    """
+    if tf.gfile.IsDirectory(checkpoint_path):
+      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
+      if not latest_checkpoint:
+        raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
+      checkpoint_path = latest_checkpoint
+
+    def _restore_fn(sess):
+      tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
+      saver.restore(sess, checkpoint_path)
+      tf.logging.info("Successfully loaded checkpoint: %s",
+                      os.path.basename(checkpoint_path))
+
+    return _restore_fn
+
+  def build_graph_from_config(self, model_config, checkpoint_path):
+    """Builds the inference graph from a configuration object.
+
+    Args:
+      model_config: Object containing configuration for building the model.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    tf.logging.info("Building model.")
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="encode")
+    model.build()
+    saver = tf.train.Saver()
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def build_graph_from_proto(self, graph_def_file, saver_def_file,
+                             checkpoint_path):
+    """Builds the inference graph from serialized GraphDef and SaverDef protos.
+
+    Args:
+      graph_def_file: File containing a serialized GraphDef proto.
+      saver_def_file: File containing a serialized SaverDef proto.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    # Load the Graph.
+    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
+    graph_def = tf.GraphDef()
+    with tf.gfile.FastGFile(graph_def_file, "rb") as f:
+      graph_def.ParseFromString(f.read())
+    tf.import_graph_def(graph_def, name="")
+
+    # Load the Saver.
+    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
+    saver_def = tf.train.SaverDef()
+    with tf.gfile.FastGFile(saver_def_file, "rb") as f:
+      saver_def.ParseFromString(f.read())
+    saver = tf.train.Saver(saver_def=saver_def)
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def _tokenize(self, item):
+    """Tokenizes an input string into a list of words."""
+    tokenized = []
+    for s in self._sentence_detector.tokenize(item):
+      tokenized.extend(nltk.tokenize.word_tokenize(s))
+
+    return tokenized
+
+  def _word_to_embedding(self, w):
+    """Returns the embedding of a word."""
+    return self._embeddings.get(w, self._embeddings[special_words.UNK])
+
+  def _preprocess(self, data, use_eos):
+    """Preprocesses text for the encoder.
+
+    Args:
+      data: A list of input strings.
+      use_eos: Whether to append the end-of-sentence word to each sentence.
+
+    Returns:
+      embeddings: A list of word embedding sequences corresponding to the input
+        strings.
+    """
+    preprocessed_data = []
+    for item in data:
+      tokenized = self._tokenize(item)
+      if use_eos:
+        tokenized.append(special_words.EOS)
+      preprocessed_data.append([self._word_to_embedding(w) for w in tokenized])
+    return preprocessed_data
+
+  def encode(self,
+             sess,
+             data,
+             use_norm=True,
+             verbose=True,
+             batch_size=128,
+             use_eos=False):
+    """Encodes a sequence of sentences as skip-thought vectors.
+
+    Args:
+      sess: TensorFlow Session.
+      data: A list of input strings.
+      use_norm: Whether to normalize skip-thought vectors to unit L2 norm.
+      verbose: Whether to log every batch.
+      batch_size: Batch size for the encoder.
+      use_eos: Whether to append the end-of-sentence word to each input
+        sentence.
+
+    Returns:
+      thought_vectors: A list of numpy arrays corresponding to the skip-thought
+        encodings of sentences in 'data'.
+    """
+    data = self._preprocess(data, use_eos)
+    thought_vectors = []
+
+    batch_indices = np.arange(0, len(data), batch_size)
+    for batch, start_index in enumerate(batch_indices):
+      if verbose:
+        tf.logging.info("Batch %d / %d.", batch, len(batch_indices))
+
+      embeddings, mask = _batch_and_pad(
+          data[start_index:start_index + batch_size])
+      feed_dict = {
+          "encode_emb:0": embeddings,
+          "encode_mask:0": mask,
+      }
+      thought_vectors.extend(
+          sess.run("encoder/thought_vectors:0", feed_dict=feed_dict))
+
+    if use_norm:
+      thought_vectors = [v / np.linalg.norm(v) for v in thought_vectors]
+
+    return thought_vectors

+ 369 - 0
skip_thoughts/skip_thoughts/skip_thoughts_model.py

@@ -0,0 +1,369 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Skip-Thoughts model for learning sentence vectors.
+
+The model is based on the paper:
+
+  "Skip-Thought Vectors"
+  Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
+  Antonio Torralba, Raquel Urtasun, Sanja Fidler.
+  https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf
+
+Layer normalization is applied based on the paper:
+
+  "Layer Normalization"
+  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+  https://arxiv.org/abs/1607.06450
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from skip_thoughts.ops import gru_cell
+from skip_thoughts.ops import input_ops
+
+
+def random_orthonormal_initializer(shape, dtype=tf.float32,
+                                   partition_info=None):  # pylint: disable=unused-argument
+  """Variable initializer that produces a random orthonormal matrix."""
+  if len(shape) != 2 or shape[0] != shape[1]:
+    raise ValueError("Expecting square shape, got %s" % shape)
+  _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
+  return u
+
+
+class SkipThoughtsModel(object):
+  """Skip-thoughts model."""
+
+  def __init__(self, config, mode="train", input_reader=None):
+    """Basic setup. The actual TensorFlow graph is constructed in build().
+
+    Args:
+      config: Object containing configuration parameters.
+      mode: "train", "eval" or "encode".
+      input_reader: Subclass of tf.ReaderBase for reading the input serialized
+        tf.Example protocol buffers. Defaults to TFRecordReader.
+
+    Raises:
+      ValueError: If mode is invalid.
+    """
+    if mode not in ["train", "eval", "encode"]:
+      raise ValueError("Unrecognized mode: %s" % mode)
+
+    self.config = config
+    self.mode = mode
+    self.reader = input_reader if input_reader else tf.TFRecordReader()
+
+    # Initializer used for non-recurrent weights.
+    self.uniform_initializer = tf.random_uniform_initializer(
+        minval=-self.config.uniform_init_scale,
+        maxval=self.config.uniform_init_scale)
+
+    # Input sentences represented as sequences of word ids. "encode" is the
+    # source sentence, "decode_pre" is the previous sentence and "decode_post"
+    # is the next sentence.
+    # Each is an int64 Tensor with  shape [batch_size, padded_length].
+    self.encode_ids = None
+    self.decode_pre_ids = None
+    self.decode_post_ids = None
+
+    # Boolean masks distinguishing real words (1) from padded words (0).
+    # Each is an int32 Tensor with shape [batch_size, padded_length].
+    self.encode_mask = None
+    self.decode_pre_mask = None
+    self.decode_post_mask = None
+
+    # Input sentences represented as sequences of word embeddings.
+    # Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
+    self.encode_emb = None
+    self.decode_pre_emb = None
+    self.decode_post_emb = None
+
+    # The output from the sentence encoder.
+    # A float32 Tensor with shape [batch_size, num_gru_units].
+    self.thought_vectors = None
+
+    # The cross entropy losses and corresponding weights of the decoders. Used
+    # for evaluation.
+    self.target_cross_entropy_losses = []
+    self.target_cross_entropy_loss_weights = []
+
+    # The total loss to optimize.
+    self.total_loss = None
+
+  def build_inputs(self):
+    """Builds the ops for reading input data.
+
+    Outputs:
+      self.encode_ids
+      self.decode_pre_ids
+      self.decode_post_ids
+      self.encode_mask
+      self.decode_pre_mask
+      self.decode_post_mask
+    """
+    if self.mode == "encode":
+      # Word embeddings are fed from an external vocabulary which has possibly
+      # been expanded (see vocabulary_expansion.py).
+      encode_ids = None
+      decode_pre_ids = None
+      decode_post_ids = None
+      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
+      decode_pre_mask = None
+      decode_post_mask = None
+    else:
+      # Prefetch serialized tf.Example protos.
+      input_queue = input_ops.prefetch_input_data(
+          self.reader,
+          self.config.input_file_pattern,
+          shuffle=self.config.shuffle_input_data,
+          capacity=self.config.input_queue_capacity,
+          num_reader_threads=self.config.num_input_reader_threads)
+
+      # Deserialize a batch.
+      serialized = input_queue.dequeue_many(self.config.batch_size)
+      encode, decode_pre, decode_post = input_ops.parse_example_batch(
+          serialized)
+
+      encode_ids = encode.ids
+      decode_pre_ids = decode_pre.ids
+      decode_post_ids = decode_post.ids
+
+      encode_mask = encode.mask
+      decode_pre_mask = decode_pre.mask
+      decode_post_mask = decode_post.mask
+
+    self.encode_ids = encode_ids
+    self.decode_pre_ids = decode_pre_ids
+    self.decode_post_ids = decode_post_ids
+
+    self.encode_mask = encode_mask
+    self.decode_pre_mask = decode_pre_mask
+    self.decode_post_mask = decode_post_mask
+
+  def build_word_embeddings(self):
+    """Builds the word embeddings.
+
+    Inputs:
+      self.encode_ids
+      self.decode_pre_ids
+      self.decode_post_ids
+
+    Outputs:
+      self.encode_emb
+      self.decode_pre_emb
+      self.decode_post_emb
+    """
+    if self.mode == "encode":
+      # Word embeddings are fed from an external vocabulary which has possibly
+      # been expanded (see vocabulary_expansion.py).
+      encode_emb = tf.placeholder(tf.float32, (
+          None, None, self.config.word_embedding_dim), "encode_emb")
+      # No sequences to decode.
+      decode_pre_emb = None
+      decode_post_emb = None
+    else:
+      word_emb = tf.get_variable(
+          name="word_embedding",
+          shape=[self.config.vocab_size, self.config.word_embedding_dim],
+          initializer=self.uniform_initializer)
+
+      encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
+      decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
+      decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)
+
+    self.encode_emb = encode_emb
+    self.decode_pre_emb = decode_pre_emb
+    self.decode_post_emb = decode_post_emb
+
+  def _initialize_gru_cell(self, num_units):
+    """Initializes a GRU cell.
+
+    The Variables of the GRU cell are initialized in a way that exactly matches
+    the skip-thoughts paper: recurrent weights are initialized from random
+    orthonormal matrices and non-recurrent weights are initialized from random
+    uniform matrices.
+
+    Args:
+      num_units: Number of output units.
+
+    Returns:
+      cell: An instance of RNNCell with variable initializers that match the
+        skip-thoughts paper.
+    """
+    return gru_cell.LayerNormGRUCell(
+        num_units,
+        w_initializer=self.uniform_initializer,
+        u_initializer=random_orthonormal_initializer,
+        b_initializer=tf.constant_initializer(0.0))
+
+  def build_encoder(self):
+    """Builds the sentence encoder.
+
+    Inputs:
+      self.encode_emb
+      self.encode_mask
+
+    Outputs:
+      self.thought_vectors
+
+    Raises:
+      ValueError: if config.bidirectional_encoder is True and config.encoder_dim
+        is odd.
+    """
+    with tf.variable_scope("encoder") as scope:
+      length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")
+
+      if self.config.bidirectional_encoder:
+        if self.config.encoder_dim % 2:
+          raise ValueError(
+              "encoder_dim must be even when using a bidirectional encoder.")
+        num_units = self.config.encoder_dim // 2
+        cell_fw = self._initialize_gru_cell(num_units)  # Forward encoder
+        cell_bw = self._initialize_gru_cell(num_units)  # Backward encoder
+        _, states = tf.nn.bidirectional_dynamic_rnn(
+            cell_fw=cell_fw,
+            cell_bw=cell_bw,
+            inputs=self.encode_emb,
+            sequence_length=length,
+            dtype=tf.float32,
+            scope=scope)
+        thought_vectors = tf.concat(states, 1, name="thought_vectors")
+      else:
+        cell = self._initialize_gru_cell(self.config.encoder_dim)
+        _, state = tf.nn.dynamic_rnn(
+            cell=cell,
+            inputs=self.encode_emb,
+            sequence_length=length,
+            dtype=tf.float32,
+            scope=scope)
+        # Use an identity operation to name the Tensor in the Graph.
+        thought_vectors = tf.identity(state, name="thought_vectors")
+
+    self.thought_vectors = thought_vectors
+
+  def _build_decoder(self, name, embeddings, targets, mask, initial_state,
+                     reuse_logits):
+    """Builds a sentence decoder.
+
+    Args:
+      name: Decoder name.
+      embeddings: Batch of sentences to decode; a float32 Tensor with shape
+        [batch_size, padded_length, emb_dim].
+      targets: Batch of target word ids; an int64 Tensor with shape
+        [batch_size, padded_length].
+      mask: A 0/1 Tensor with shape [batch_size, padded_length].
+      initial_state: Initial state of the GRU. A float32 Tensor with shape
+        [batch_size, num_gru_cells].
+      reuse_logits: Whether to reuse the logits weights.
+    """
+    # Decoder RNN.
+    cell = self._initialize_gru_cell(self.config.encoder_dim)
+    with tf.variable_scope(name) as scope:
+      # Add a padding word at the start of each sentence (to correspond to the
+      # prediction of the first word) and remove the last word.
+      decoder_input = tf.pad(
+          embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
+      length = tf.reduce_sum(mask, 1, name="length")
+      decoder_output, _ = tf.nn.dynamic_rnn(
+          cell=cell,
+          inputs=decoder_input,
+          sequence_length=length,
+          initial_state=initial_state,
+          scope=scope)
+
+    # Stack batch vertically.
+    decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
+    targets = tf.reshape(targets, [-1])
+    weights = tf.to_float(tf.reshape(mask, [-1]))
+
+    # Logits.
+    with tf.variable_scope("logits", reuse=reuse_logits) as scope:
+      logits = tf.contrib.layers.fully_connected(
+          inputs=decoder_output,
+          num_outputs=self.config.vocab_size,
+          activation_fn=None,
+          weights_initializer=self.uniform_initializer,
+          scope=scope)
+
+    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=targets, logits=logits)
+    batch_loss = tf.reduce_sum(losses * weights)
+    tf.losses.add_loss(batch_loss)
+
+    tf.summary.scalar("losses/" + name, batch_loss)
+
+    self.target_cross_entropy_losses.append(losses)
+    self.target_cross_entropy_loss_weights.append(weights)
+
+  def build_decoders(self):
+    """Builds the sentence decoders.
+
+    Inputs:
+      self.decode_pre_emb
+      self.decode_post_emb
+      self.decode_pre_ids
+      self.decode_post_ids
+      self.decode_pre_mask
+      self.decode_post_mask
+      self.thought_vectors
+
+    Outputs:
+      self.target_cross_entropy_losses
+      self.target_cross_entropy_loss_weights
+    """
+    if self.mode != "encode":
+      # Pre-sentence decoder.
+      self._build_decoder("decoder_pre", self.decode_pre_emb,
+                          self.decode_pre_ids, self.decode_pre_mask,
+                          self.thought_vectors, False)
+
+      # Post-sentence decoder. Logits weights are reused.
+      self._build_decoder("decoder_post", self.decode_post_emb,
+                          self.decode_post_ids, self.decode_post_mask,
+                          self.thought_vectors, True)
+
+  def build_loss(self):
+    """Builds the loss Tensor.
+
+    Outputs:
+      self.total_loss
+    """
+    if self.mode != "encode":
+      total_loss = tf.losses.get_total_loss()
+      tf.summary.scalar("losses/total", total_loss)
+
+      self.total_loss = total_loss
+
+  def build_global_step(self):
+    """Builds the global step Tensor.
+
+    Outputs:
+      self.global_step
+    """
+    self.global_step = tf.contrib.framework.create_global_step()
+
+  def build(self):
+    """Creates all ops for training, evaluation or encoding."""
+    self.build_inputs()
+    self.build_word_embeddings()
+    self.build_encoder()
+    self.build_decoders()
+    self.build_loss()
+    self.build_global_step()

+ 191 - 0
skip_thoughts/skip_thoughts/skip_thoughts_model_test.py

@@ -0,0 +1,191 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow_models.skip_thoughts.skip_thoughts_model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+
+class SkipThoughtsModel(skip_thoughts_model.SkipThoughtsModel):
+  """Subclass of SkipThoughtsModel without the disk I/O."""
+
+  def build_inputs(self):
+    if self.mode == "encode":
+      # Encode mode doesn't read from disk, so defer to parent.
+      return super(SkipThoughtsModel, self).build_inputs()
+    else:
+      # Replace disk I/O with random Tensors.
+      self.encode_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.decode_pre_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.decode_post_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.encode_mask = tf.ones_like(self.encode_ids)
+      self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
+      self.decode_post_mask = tf.ones_like(self.decode_post_ids)
+
+
+class SkipThoughtsModelTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(SkipThoughtsModelTest, self).setUp()
+    self._model_config = configuration.model_config()
+
+  def _countModelParameters(self):
+    """Counts the number of parameters in the model at top level scope."""
+    counter = {}
+    for v in tf.global_variables():
+      name = v.op.name.split("/")[0]
+      num_params = v.get_shape().num_elements()
+      if not num_params:
+        self.fail("Could not infer num_elements from Variable %s" % v.op.name)
+      counter[name] = counter.get(name, 0) + num_params
+    return counter
+
+  def _checkModelParameters(self):
+    """Verifies the number of parameters in the model."""
+    param_counts = self._countModelParameters()
+    expected_param_counts = {
+        # vocab_size * embedding_size
+        "word_embedding": 12400000,
+        # GRU Cells
+        "encoder": 21772800,
+        "decoder_pre": 21772800,
+        "decoder_post": 21772800,
+        # (encoder_dim + 1) * vocab_size
+        "logits": 48020000,
+        "global_step": 1,
+    }
+    self.assertDictEqual(expected_param_counts, param_counts)
+
+  def _checkOutputs(self, expected_shapes, feed_dict=None):
+    """Verifies that the model produces expected outputs.
+
+    Args:
+      expected_shapes: A dict mapping Tensor or Tensor name to expected output
+        shape.
+      feed_dict: Values of Tensors to feed into Session.run().
+    """
+    fetches = expected_shapes.keys()
+
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      outputs = sess.run(fetches, feed_dict)
+
+    for index, output in enumerate(outputs):
+      tensor = fetches[index]
+      expected = expected_shapes[tensor]
+      actual = output.shape
+      if expected != actual:
+        self.fail("Tensor %s has shape %s (expected %s)." % (tensor, actual,
+                                                             expected))
+
+  def testBuildForTraining(self):
+    model = SkipThoughtsModel(self._model_config, mode="train")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, length]
+        model.encode_ids: (128, 15),
+        model.decode_pre_ids: (128, 15),
+        model.decode_post_ids: (128, 15),
+        model.encode_mask: (128, 15),
+        model.decode_pre_mask: (128, 15),
+        model.decode_post_mask: (128, 15),
+        # [batch_size, length, word_embedding_dim]
+        model.encode_emb: (128, 15, 620),
+        model.decode_pre_emb: (128, 15, 620),
+        model.decode_post_emb: (128, 15, 620),
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (128, 2400),
+        # [batch_size * length]
+        model.target_cross_entropy_losses[0]: (1920,),
+        model.target_cross_entropy_losses[1]: (1920,),
+        # [batch_size * length]
+        model.target_cross_entropy_loss_weights[0]: (1920,),
+        model.target_cross_entropy_loss_weights[1]: (1920,),
+        # Scalar
+        model.total_loss: (),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForEval(self):
+    model = SkipThoughtsModel(self._model_config, mode="eval")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, length]
+        model.encode_ids: (128, 15),
+        model.decode_pre_ids: (128, 15),
+        model.decode_post_ids: (128, 15),
+        model.encode_mask: (128, 15),
+        model.decode_pre_mask: (128, 15),
+        model.decode_post_mask: (128, 15),
+        # [batch_size, length, word_embedding_dim]
+        model.encode_emb: (128, 15, 620),
+        model.decode_pre_emb: (128, 15, 620),
+        model.decode_post_emb: (128, 15, 620),
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (128, 2400),
+        # [batch_size * length]
+        model.target_cross_entropy_losses[0]: (1920,),
+        model.target_cross_entropy_losses[1]: (1920,),
+        # [batch_size * length]
+        model.target_cross_entropy_loss_weights[0]: (1920,),
+        model.target_cross_entropy_loss_weights[1]: (1920,),
+        # Scalar
+        model.total_loss: (),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForEncode(self):
+    model = SkipThoughtsModel(self._model_config, mode="encode")
+    model.build()
+
+    # Test feeding a batch of word embeddings to get skip thought vectors.
+    encode_emb = np.random.rand(64, 15, 620)
+    encode_mask = np.ones((64, 15), dtype=np.int64)
+    feed_dict = {model.encode_emb: encode_emb, model.encode_mask: encode_mask}
+    expected_shapes = {
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (64, 2400),
+    }
+    self._checkOutputs(expected_shapes, feed_dict)
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 199 - 0
skip_thoughts/skip_thoughts/track_perplexity.py

@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tracks training progress via per-word perplexity.
+
+This script should be run concurrently with training so that summaries show up
+in TensorBoard.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os.path
+import time
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", None,
+                       "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("checkpoint_dir", None,
+                       "Directory containing model checkpoints.")
+tf.flags.DEFINE_string("eval_dir", None, "Directory to write event logs to.")
+
+tf.flags.DEFINE_integer("eval_interval_secs", 600,
+                        "Interval between evaluation runs.")
+tf.flags.DEFINE_integer("num_eval_examples", 50000,
+                        "Number of examples for evaluation.")
+
+tf.flags.DEFINE_integer("min_global_step", 100,
+                        "Minimum global step to run evaluation.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def evaluate_model(sess, losses, weights, num_batches, global_step,
+                   summary_writer, summary_op):
+  """Computes perplexity-per-word over the evaluation dataset.
+
+  Summaries and perplexity-per-word are written out to the eval directory.
+
+  Args:
+    sess: Session object.
+    losses: A Tensor of any shape; the target cross entropy losses for the
+      current batch.
+    weights: A Tensor of weights corresponding to losses.
+    num_batches: Integer; the number of evaluation batches.
+    global_step: Integer; global step of the model checkpoint.
+    summary_writer: Instance of SummaryWriter.
+    summary_op: Op for generating model summaries.
+  """
+  # Log model summaries on a single batch.
+  summary_str = sess.run(summary_op)
+  summary_writer.add_summary(summary_str, global_step)
+
+  start_time = time.time()
+  sum_losses = 0.0
+  sum_weights = 0.0
+  for i in xrange(num_batches):
+    batch_losses, batch_weights = sess.run([losses, weights])
+    sum_losses += np.sum(batch_losses * batch_weights)
+    sum_weights += np.sum(batch_weights)
+    if not i % 100:
+      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
+                      num_batches)
+  eval_time = time.time() - start_time
+
+  perplexity = math.exp(sum_losses / sum_weights)
+  tf.logging.info("Perplexity = %f (%.2f sec)", perplexity, eval_time)
+
+  # Log perplexity to the SummaryWriter.
+  summary = tf.Summary()
+  value = summary.value.add()
+  value.simple_value = perplexity
+  value.tag = "perplexity"
+  summary_writer.add_summary(summary, global_step)
+
+  # Write the Events file to the eval directory.
+  summary_writer.flush()
+  tf.logging.info("Finished processing evaluation at global step %d.",
+                  global_step)
+
+
+def run_once(model, losses, weights, saver, summary_writer, summary_op):
+  """Evaluates the latest model checkpoint.
+
+  Args:
+    model: Instance of SkipThoughtsModel; the model to evaluate.
+    losses: Tensor; the target cross entropy losses for the current batch.
+    weights: A Tensor of weights corresponding to losses.
+    saver: Instance of tf.train.Saver for restoring model Variables.
+    summary_writer: Instance of FileWriter.
+    summary_op: Op for generating model summaries.
+  """
+  model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+  if not model_path:
+    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
+                    FLAGS.checkpoint_dir)
+    return
+
+  with tf.Session() as sess:
+    # Load model from checkpoint.
+    tf.logging.info("Loading model from checkpoint: %s", model_path)
+    saver.restore(sess, model_path)
+    global_step = tf.train.global_step(sess, model.global_step.name)
+    tf.logging.info("Successfully loaded %s at global step = %d.",
+                    os.path.basename(model_path), global_step)
+    if global_step < FLAGS.min_global_step:
+      tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
+                      FLAGS.min_global_step)
+      return
+
+    # Start the queue runners.
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(coord=coord)
+
+    num_eval_batches = int(
+        math.ceil(FLAGS.num_eval_examples / model.config.batch_size))
+
+    # Run evaluation on the latest checkpoint.
+    try:
+      evaluate_model(sess, losses, weights, num_eval_batches, global_step,
+                     summary_writer, summary_op)
+    except tf.InvalidArgumentError:
+      tf.logging.error(
+          "Evaluation raised InvalidArgumentError (e.g. due to Nans).")
+    finally:
+      coord.request_stop()
+      coord.join(threads, stop_grace_period_secs=10)
+
+
+def main(unused_argv):
+  if not FLAGS.input_file_pattern:
+    raise ValueError("--input_file_pattern is required.")
+  if not FLAGS.checkpoint_dir:
+    raise ValueError("--checkpoint_dir is required.")
+  if not FLAGS.eval_dir:
+    raise ValueError("--eval_dir is required.")
+
+  # Create the evaluation directory if it doesn't exist.
+  eval_dir = FLAGS.eval_dir
+  if not tf.gfile.IsDirectory(eval_dir):
+    tf.logging.info("Creating eval directory: %s", eval_dir)
+    tf.gfile.MakeDirs(eval_dir)
+
+  g = tf.Graph()
+  with g.as_default():
+    # Build the model for evaluation.
+    model_config = configuration.model_config(
+        input_file_pattern=FLAGS.input_file_pattern,
+        input_queue_capacity=FLAGS.num_eval_examples,
+        shuffle_input_data=False)
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="eval")
+    model.build()
+
+    losses = tf.concat(model.target_cross_entropy_losses, 0)
+    weights = tf.concat(model.target_cross_entropy_loss_weights, 0)
+
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
+    # Create the summary operation and the summary writer.
+    summary_op = tf.summary.merge_all()
+    summary_writer = tf.summary.FileWriter(eval_dir)
+
+    g.finalize()
+
+    # Run a new evaluation run every eval_interval_secs.
+    while True:
+      start = time.time()
+      tf.logging.info("Starting evaluation at " + time.strftime(
+          "%Y-%m-%d-%H:%M:%S", time.localtime()))
+      run_once(model, losses, weights, saver, summary_writer, summary_op)
+      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
+      if time_to_next_eval > 0:
+        time.sleep(time_to_next_eval)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 99 - 0
skip_thoughts/skip_thoughts/train.py

@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Train the skip-thoughts model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", None,
+                       "File pattern of sharded TFRecord files containing "
+                       "tf.Example protos.")
+tf.flags.DEFINE_string("train_dir", None,
+                       "Directory for saving and loading checkpoints.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _setup_learning_rate(config, global_step):
+  """Sets up the learning rate with optional exponential decay.
+
+  Args:
+    config: Object containing learning rate configuration parameters.
+    global_step: Tensor; the global step.
+
+  Returns:
+    learning_rate: Tensor; the learning rate with exponential decay.
+  """
+  if config.learning_rate_decay_factor > 0:
+    learning_rate = tf.train.exponential_decay(
+        learning_rate=float(config.learning_rate),
+        global_step=global_step,
+        decay_steps=config.learning_rate_decay_steps,
+        decay_rate=config.learning_rate_decay_factor,
+        staircase=False)
+  else:
+    learning_rate = tf.constant(config.learning_rate)
+  return learning_rate
+
+
+def main(unused_argv):
+  if not FLAGS.input_file_pattern:
+    raise ValueError("--input_file_pattern is required.")
+  if not FLAGS.train_dir:
+    raise ValueError("--train_dir is required.")
+
+  model_config = configuration.model_config(
+      input_file_pattern=FLAGS.input_file_pattern)
+  training_config = configuration.training_config()
+
+  tf.logging.info("Building training graph.")
+  g = tf.Graph()
+  with g.as_default():
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="train")
+    model.build()
+
+    learning_rate = _setup_learning_rate(training_config, model.global_step)
+    optimizer = tf.train.AdamOptimizer(learning_rate)
+
+    train_tensor = tf.contrib.slim.learning.create_train_op(
+        total_loss=model.total_loss,
+        optimizer=optimizer,
+        global_step=model.global_step,
+        clip_gradient_norm=training_config.clip_gradient_norm)
+
+    saver = tf.train.Saver()
+
+  tf.contrib.slim.learning.train(
+      train_op=train_tensor,
+      logdir=FLAGS.train_dir,
+      graph=g,
+      global_step=model.global_step,
+      number_of_steps=training_config.number_of_steps,
+      save_summaries_secs=training_config.save_summaries_secs,
+      saver=saver,
+      save_interval_secs=training_config.save_model_secs)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 203 - 0
skip_thoughts/skip_thoughts/vocabulary_expansion.py

@@ -0,0 +1,203 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compute an expanded vocabulary of embeddings using a word2vec model.
+
+This script loads the word embeddings from a trained skip-thoughts model and
+from a trained word2vec model (typically with a larger vocabulary). It trains a
+linear regression model without regularization to learn a linear mapping from
+the word2vec embedding space to the skip-thoughts embedding space. The model is
+then applied to all words in the word2vec vocabulary, yielding vectors in the
+skip-thoughts word embedding space for the union of the two vocabularies.
+
+The linear regression task is to learn a parameter matrix W to minimize
+  || X - Y * W ||^2,
+where X is a matrix of skip-thoughts embeddings of shape [num_words, dim1],
+Y is a matrix of word2vec embeddings of shape [num_words, dim2], and W is a
+matrix of shape [dim2, dim1].
+
+This is based on the "Translation Matrix" method from the paper:
+
+  "Exploiting Similarities among Languages for Machine Translation"
+  Tomas Mikolov, Quoc V. Le, Ilya Sutskever
+  https://arxiv.org/abs/1309.4168
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os.path
+
+
+import gensim.models
+import numpy as np
+import sklearn.linear_model
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("skip_thoughts_model", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+
+tf.flags.DEFINE_string("skip_thoughts_vocab", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+
+tf.flags.DEFINE_string("word2vec_model", None,
+                       "File containing a word2vec model in binary format.")
+
+tf.flags.DEFINE_string("output_dir", None, "Output directory.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _load_skip_thoughts_embeddings(checkpoint_path):
+  """Loads the embedding matrix from a skip-thoughts model checkpoint.
+
+  Args:
+    checkpoint_path: Model checkpoint file or directory containing a checkpoint
+        file.
+
+  Returns:
+    word_embedding: A numpy array of shape [vocab_size, embedding_dim].
+
+  Raises:
+    ValueError: If no checkpoint file matches checkpoint_path.
+  """
+  if tf.gfile.IsDirectory(checkpoint_path):
+    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
+    if not checkpoint_file:
+      raise ValueError("No checkpoint file found in %s" % checkpoint_path)
+  else:
+    checkpoint_file = checkpoint_path
+
+  tf.logging.info("Loading skip-thoughts embedding matrix from %s",
+                  checkpoint_file)
+  reader = tf.train.NewCheckpointReader(checkpoint_file)
+  word_embedding = reader.get_tensor("word_embedding")
+  tf.logging.info("Loaded skip-thoughts embedding matrix of shape %s",
+                  word_embedding.shape)
+
+  return word_embedding
+
+
+def _load_vocabulary(filename):
+  """Loads a vocabulary file.
+
+  Args:
+    filename: Path to text file containing newline-separated words.
+
+  Returns:
+    vocab: A dictionary mapping word to word id.
+  """
+  tf.logging.info("Reading vocabulary from %s", filename)
+  vocab = collections.OrderedDict()
+  with tf.gfile.GFile(filename, mode="r") as f:
+    for i, line in enumerate(f):
+      word = line.decode("utf-8").strip()
+      assert word not in vocab, "Attempting to add word twice: %s" % word
+      vocab[word] = i
+  tf.logging.info("Read vocabulary of size %d", len(vocab))
+  return vocab
+
+
+def _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab, word2vec):
+  """Runs vocabulary expansion on a skip-thoughts model using a word2vec model.
+
+  Args:
+    skip_thoughts_emb: A numpy array of shape [skip_thoughts_vocab_size,
+        skip_thoughts_embedding_dim].
+    skip_thoughts_vocab: A dictionary of word to id.
+    word2vec: An instance of gensim.models.Word2Vec.
+
+  Returns:
+    combined_emb: A dictionary mapping words to embedding vectors.
+  """
+  # Find words shared between the two vocabularies.
+  tf.logging.info("Finding shared words")
+  shared_words = [w for w in word2vec.vocab if w in skip_thoughts_vocab]
+
+  # Select embedding vectors for shared words.
+  tf.logging.info("Selecting embeddings for %d shared words", len(shared_words))
+  shared_st_emb = skip_thoughts_emb[[
+      skip_thoughts_vocab[w] for w in shared_words
+  ]]
+  shared_w2v_emb = word2vec[shared_words]
+
+  # Train a linear regression model on the shared embedding vectors.
+  tf.logging.info("Training linear regression model")
+  model = sklearn.linear_model.LinearRegression()
+  model.fit(shared_w2v_emb, shared_st_emb)
+
+  # Create the expanded vocabulary.
+  tf.logging.info("Creating embeddings for expanded vocabuary")
+  combined_emb = collections.OrderedDict()
+  for w in word2vec.vocab:
+    # Ignore words with underscores (spaces).
+    if "_" not in w:
+      w_emb = model.predict(word2vec[w].reshape(1, -1))
+      combined_emb[w] = w_emb.reshape(-1)
+
+  for w in skip_thoughts_vocab:
+    combined_emb[w] = skip_thoughts_emb[skip_thoughts_vocab[w]]
+
+  tf.logging.info("Created expanded vocabulary of %d words", len(combined_emb))
+
+  return combined_emb
+
+
+def main(unused_argv):
+  if not FLAGS.skip_thoughts_model:
+    raise ValueError("--skip_thoughts_model is required.")
+  if not FLAGS.skip_thoughts_vocab:
+    raise ValueError("--skip_thoughts_vocab is required.")
+  if not FLAGS.word2vec_model:
+    raise ValueError("--word2vec_model is required.")
+  if not FLAGS.output_dir:
+    raise ValueError("--output_dir is required.")
+
+  if not tf.gfile.IsDirectory(FLAGS.output_dir):
+    tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  # Load the skip-thoughts embeddings and vocabulary.
+  skip_thoughts_emb = _load_skip_thoughts_embeddings(FLAGS.skip_thoughts_model)
+  skip_thoughts_vocab = _load_vocabulary(FLAGS.skip_thoughts_vocab)
+
+  # Load the Word2Vec model.
+  word2vec = gensim.models.Word2Vec.load_word2vec_format(
+      FLAGS.word2vec_model, binary=True)
+
+  # Run vocabulary expansion.
+  embedding_map = _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab,
+                                     word2vec)
+
+  # Save the output.
+  vocab = embedding_map.keys()
+  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
+  with tf.gfile.GFile(vocab_file, "w") as f:
+    f.write("\n".join(vocab))
+  tf.logging.info("Wrote vocabulary file to %s", vocab_file)
+
+  embeddings = np.array(embedding_map.values())
+  embeddings_file = os.path.join(FLAGS.output_dir, "embeddings.npy")
+  np.save(embeddings_file, embeddings)
+  tf.logging.info("Wrote embeddings file to %s", embeddings_file)
+
+
+if __name__ == "__main__":
+  tf.app.run()