Browse Source

Merge https://github.com/tensorflow/models

Ivan Bogatyy 8 years ago
parent
commit
b7523ee562
39 changed files with 2917 additions and 60 deletions
  1. README.md (+1 -0)
  2. autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py (+3 -4)
  3. autoencoder/AutoencoderRunner.py (+3 -4)
  4. autoencoder/MaskingNoiseAutoencoderRunner.py (+3 -4)
  5. autoencoder/Utils.py (+0 -9)
  6. autoencoder/VariationalAutoencoderRunner.py (+3 -4)
  7. autoencoder/autoencoder_models/Autoencoder.py (+3 -4)
  8. autoencoder/autoencoder_models/DenoisingAutoencoder.py (+8 -9)
  9. autoencoder/autoencoder_models/VariationalAutoencoder.py (+4 -3)
  10. im2txt/README.md (+2 -10)
  11. inception/inception/data/build_image_data.py (+6 -1)
  12. resnet/resnet_model.py (+1 -1)
  13. skip_thoughts/.gitignore (+8 -0)
  14. skip_thoughts/README.md (+471 -0)
  15. skip_thoughts/WORKSPACE (+0 -0)
  16. skip_thoughts/skip_thoughts/BUILD (+94 -0)
  17. skip_thoughts/skip_thoughts/__init__.py (+0 -0)
  18. skip_thoughts/skip_thoughts/configuration.py (+110 -0)
  19. skip_thoughts/skip_thoughts/data/BUILD (+23 -0)
  20. skip_thoughts/skip_thoughts/data/__init__.py (+0 -0)
  21. skip_thoughts/skip_thoughts/data/preprocess_dataset.py (+301 -0)
  22. skip_thoughts/skip_thoughts/data/special_words.py (+27 -0)
  23. skip_thoughts/skip_thoughts/encoder_manager.py (+134 -0)
  24. skip_thoughts/skip_thoughts/evaluate.py (+117 -0)
  25. skip_thoughts/skip_thoughts/ops/BUILD (+17 -0)
  26. skip_thoughts/skip_thoughts/ops/__init__.py (+0 -0)
  27. skip_thoughts/skip_thoughts/ops/gru_cell.py (+134 -0)
  28. skip_thoughts/skip_thoughts/ops/input_ops.py (+118 -0)
  29. skip_thoughts/skip_thoughts/skip_thoughts_encoder.py (+258 -0)
  30. skip_thoughts/skip_thoughts/skip_thoughts_model.py (+369 -0)
  31. skip_thoughts/skip_thoughts/skip_thoughts_model_test.py (+191 -0)
  32. skip_thoughts/skip_thoughts/track_perplexity.py (+199 -0)
  33. skip_thoughts/skip_thoughts/train.py (+99 -0)
  34. skip_thoughts/skip_thoughts/vocabulary_expansion.py (+203 -0)
  35. syntaxnet/g3doc/universal.md (+1 -1)
  36. textsum/seq2seq_attention_model.py (+3 -3)
  37. transformer/cluttered_mnist.py (+1 -1)
  38. tutorials/README.md (+1 -1)
  39. tutorials/rnn/translate/seq2seq_model.py (+1 -1)

+ 1 - 0
README.md

@@ -21,6 +21,7 @@ To propose a model for inclusion please submit a pull request.
 - [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
 - [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
 - [resnet](resnet): deep and wide residual networks.
+- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
 - [slim](slim): image classification models in TF-Slim.
 - [street](street): identify the name of a street (in France) from an image using a Deep RNN.
 - [swivel](swivel): the Swivel algorithm for generating word embeddings.

+ 3 - 4
autoencoder/AdditiveGaussianNoiseAutoencoderRunner.py

@@ -4,7 +4,7 @@ import sklearn.preprocessing as prep
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
-from autoencoder.autoencoder_models.DenoisingAutoencoder import AdditiveGaussianNoiseAutoencoder
+from autoencoder_models.DenoisingAutoencoder import AdditiveGaussianNoiseAutoencoder
 
 mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
 
@@ -45,7 +45,6 @@ for epoch in range(training_epochs):
 
     # Display logs per epoch step
     if epoch % display_step == 0:
-        print "Epoch:", '%04d' % (epoch + 1), \
-            "cost=", "{:.9f}".format(avg_cost)
+        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
 
-print "Total cost: " + str(autoencoder.calc_total_cost(X_test))
+print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))

+ 3 - 4
autoencoder/AutoencoderRunner.py

@@ -4,7 +4,7 @@ import sklearn.preprocessing as prep
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
-from autoencoder.autoencoder_models.Autoencoder import Autoencoder
+from autoencoder_models.Autoencoder import Autoencoder
 
 mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
 
@@ -44,7 +44,6 @@ for epoch in range(training_epochs):
 
     # Display logs per epoch step
     if epoch % display_step == 0:
-        print "Epoch:", '%04d' % (epoch + 1), \
-            "cost=", "{:.9f}".format(avg_cost)
+        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
 
-print "Total cost: " + str(autoencoder.calc_total_cost(X_test))
+print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))

+ 3 - 4
autoencoder/MaskingNoiseAutoencoderRunner.py

@@ -4,7 +4,7 @@ import sklearn.preprocessing as prep
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
-from autoencoder.autoencoder_models.DenoisingAutoencoder import MaskingNoiseAutoencoder
+from autoencoder_models.DenoisingAutoencoder import MaskingNoiseAutoencoder
 
 mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
 
@@ -43,7 +43,6 @@ for epoch in range(training_epochs):
         avg_cost += cost / n_samples * batch_size
 
     if epoch % display_step == 0:
-        print "Epoch:", '%04d' % (epoch + 1), \
-            "cost=", "{:.9f}".format(avg_cost)
+        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
 
-print "Total cost: " + str(autoencoder.calc_total_cost(X_test))
+print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))

+ 0 - 9
autoencoder/Utils.py

@@ -1,9 +0,0 @@
-import numpy as np
-import tensorflow as tf
-
-def xavier_init(fan_in, fan_out, constant = 1):
-    low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
-    high = constant * np.sqrt(6.0 / (fan_in + fan_out))
-    return tf.random_uniform((fan_in, fan_out),
-                             minval = low, maxval = high,
-                             dtype = tf.float32)
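For context, the helper removed above and the built-in initializer that the later hunks switch to compute the same Xavier/Glorot uniform bounds. A minimal TF1-era sketch (layer sizes are illustrative, not taken from the commit):

```python
import numpy as np
import tensorflow as tf

fan_in, fan_out = 784, 200  # illustrative layer sizes

# What the deleted xavier_init computed: U(-limit, limit) with
# limit = sqrt(6 / (fan_in + fan_out)).
limit = np.sqrt(6.0 / (fan_in + fan_out))
w_manual = tf.random_uniform([fan_in, fan_out], minval=-limit, maxval=limit,
                             dtype=tf.float32)

# The replacement used throughout this commit: the contrib initializer
# defaults to the same uniform Glorot scheme.
w_builtin = tf.get_variable("w1", shape=[fan_in, fan_out],
                            initializer=tf.contrib.layers.xavier_initializer())
```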

+ 3 - 4
autoencoder/VariationalAutoencoderRunner.py

@@ -4,7 +4,7 @@ import sklearn.preprocessing as prep
 import tensorflow as tf
 from tensorflow.examples.tutorials.mnist import input_data
 
-from autoencoder.autoencoder_models.VariationalAutoencoder import VariationalAutoencoder
+from autoencoder_models.VariationalAutoencoder import VariationalAutoencoder
 
 mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
 
@@ -47,7 +47,6 @@ for epoch in range(training_epochs):
 
     # Display logs per epoch step
     if epoch % display_step == 0:
-        print "Epoch:", '%04d' % (epoch + 1), \
-            "cost=", "{:.9f}".format(avg_cost)
+        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
 
-print "Total cost: " + str(autoencoder.calc_total_cost(X_test))
+print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))

+ 3 - 4
autoencoder/autoencoder_models/Autoencoder.py

@@ -1,6 +1,4 @@
 import tensorflow as tf
-import numpy as np
-import autoencoder.Utils
 
 class Autoencoder(object):
 
@@ -28,7 +26,8 @@ class Autoencoder(object):
 
     def _initialize_weights(self):
         all_weights = dict()
-        all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
+        all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
+            initializer=tf.contrib.layers.xavier_initializer())
         all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
         all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))
         all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
@@ -46,7 +45,7 @@ class Autoencoder(object):
 
     def generate(self, hidden = None):
         if hidden is None:
-            hidden = np.random.normal(size=self.weights["b1"])
+            hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
         return self.sess.run(self.reconstruction, feed_dict={self.hidden: hidden})
 
     def reconstruct(self, X):

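With the change above, generate() samples a Gaussian code of shape [1, n_hidden] instead of passing a TensorFlow variable to np.random.normal's size argument. A hedged usage sketch (constructor arguments are assumed to mirror the runner scripts; they are not shown in this hunk):

```python
import tensorflow as tf

from autoencoder_models.Autoencoder import Autoencoder

# MNIST-sized dimensions, matching the runner scripts; values are illustrative.
ae = Autoencoder(n_input=784, n_hidden=200,
                 transfer_function=tf.nn.softplus,
                 optimizer=tf.train.AdamOptimizer())

sample = ae.generate()   # decodes a random [1, 200] hidden code
print(sample.shape)      # -> (1, 784)
```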
+ 8 - 9
autoencoder/autoencoder_models/DenoisingAutoencoder.py

@@ -1,7 +1,4 @@
 import tensorflow as tf
-import numpy as np
-import autoencoder.Utils
-
 
 class AdditiveGaussianNoiseAutoencoder(object):
     def __init__(self, n_input, n_hidden, transfer_function = tf.nn.softplus, optimizer = tf.train.AdamOptimizer(),
@@ -31,7 +28,8 @@ class AdditiveGaussianNoiseAutoencoder(object):
 
     def _initialize_weights(self):
         all_weights = dict()
-        all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
+        all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
+            initializer=tf.contrib.layers.xavier_initializer())
         all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype = tf.float32))
         all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype = tf.float32))
         all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype = tf.float32))
@@ -53,9 +51,9 @@ class AdditiveGaussianNoiseAutoencoder(object):
                                                        self.scale: self.training_scale
                                                        })
 
-    def generate(self, hidden = None):
+    def generate(self, hidden=None):
         if hidden is None:
-            hidden = np.random.normal(size = self.weights["b1"])
+            hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
         return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})
 
     def reconstruct(self, X):
@@ -98,7 +96,8 @@ class MaskingNoiseAutoencoder(object):
 
     def _initialize_weights(self):
         all_weights = dict()
-        all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
+        all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
+            initializer=tf.contrib.layers.xavier_initializer())
         all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype = tf.float32))
         all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype = tf.float32))
         all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype = tf.float32))
@@ -115,9 +114,9 @@ class MaskingNoiseAutoencoder(object):
     def transform(self, X):
         return self.sess.run(self.hidden, feed_dict = {self.x: X, self.keep_prob: 1.0})
 
-    def generate(self, hidden = None):
+    def generate(self, hidden=None):
         if hidden is None:
-            hidden = np.random.normal(size = self.weights["b1"])
+            hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
         return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})
 
     def reconstruct(self, X):

+ 4 - 3
autoencoder/autoencoder_models/VariationalAutoencoder.py

@@ -1,6 +1,5 @@
 import tensorflow as tf
 import numpy as np
-import autoencoder.Utils
 
 class VariationalAutoencoder(object):
 
@@ -36,8 +35,10 @@ class VariationalAutoencoder(object):
 
     def _initialize_weights(self):
         all_weights = dict()
-        all_weights['w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
-        all_weights['log_sigma_w1'] = tf.Variable(autoencoder.Utils.xavier_init(self.n_input, self.n_hidden))
+        all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
+            initializer=tf.contrib.layers.xavier_initializer())
+        all_weights['log_sigma_w1'] = tf.get_variable("log_sigma_w1", shape=[self.n_input, self.n_hidden],
+            initializer=tf.contrib.layers.xavier_initializer())
         all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
         all_weights['log_sigma_b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
         all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))

+ 2 - 10
im2txt/README.md

@@ -37,9 +37,7 @@ Full text available at: http://arxiv.org/abs/1609.06647
 The *Show and Tell* model is a deep neural network that learns how to describe
 the content of images. For example:
 
-<center>
 ![Example captions](g3doc/example_captions.jpg)
-</center>
 
 ### Architecture
 
@@ -66,9 +64,7 @@ learned during training.
 
 The following diagram illustrates the model architecture.
 
-<center>
 ![Show and Tell Architecture](g3doc/show_and_tell_architecture.png)
-</center>
 
 In this diagram, \{*s*<sub>0</sub>, *s*<sub>1</sub>, ..., *s*<sub>*N*-1</sub>\}
 are the words of the caption and \{*w*<sub>*e*</sub>*s*<sub>0</sub>,
@@ -137,8 +133,7 @@ Each caption is a list of words. During preprocessing, a dictionary is created
 that assigns each word in the vocabulary to an integer-valued id. Each caption
 is encoded as a list of integer word ids in the `tf.SequenceExample` protos.
 
-We have provided a script to download and preprocess the [MSCOCO]
-(http://mscoco.org/) image captioning data set into this format. Downloading
+We have provided a script to download and preprocess the [MSCOCO](http://mscoco.org/) image captioning data set into this format. Downloading
 and preprocessing the data may take several hours depending on your network and
 computer speed. Please be patient.
 
@@ -266,8 +261,7 @@ tensorboard --logdir="${MODEL_DIR}"
 ### Fine Tune the Inception v3 Model
 
 Your model will already be able to generate reasonable captions after the first
-phase of training. Try it out! (See [Generating Captions]
-(#generating-captions)).
+phase of training. Try it out! (See [Generating Captions](#generating-captions)).
 
 You can further improve the performance of the model by running a
 second training phase to jointly fine-tune the parameters of the *Inception v3*
@@ -337,6 +331,4 @@ expected.
 
 Here is the image:
 
-<center>
 ![Surfer](g3doc/COCO_val2014_000000224477.jpg)
-</center>

+ 6 - 1
inception/inception/data/build_image_data.py

@@ -261,7 +261,12 @@ def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
       label = labels[i]
       text = texts[i]
 
-      image_buffer, height, width = _process_image(filename, coder)
+      try:
+        image_buffer, height, width = _process_image(filename, coder)
+      except Exception as e:
+        print(e)
+        print('SKIPPED: Unexpected error while decoding %s.' % filename)
+        continue
 
       example = _convert_to_example(filename, image_buffer, label,
                                     text, height, width)

+ 1 - 1
resnet/resnet_model.py

@@ -128,7 +128,7 @@ class ResNet(object):
   def _build_train_op(self):
     """Build training specific ops for the graph."""
     self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
-    tf.summary.scalar('learning rate', self.lrn_rate)
+    tf.summary.scalar('learning_rate', self.lrn_rate)
 
     trainable_variables = tf.trainable_variables()
     grads = tf.gradients(self.cost, trainable_variables)

+ 8 - 0
skip_thoughts/.gitignore

@@ -0,0 +1,8 @@
+/bazel-bin
+/bazel-ci_build-cache
+/bazel-genfiles
+/bazel-out
+/bazel-skip_thoughts
+/bazel-testlogs
+/bazel-tf
+*.pyc

+ 471 - 0
skip_thoughts/README.md

@@ -0,0 +1,471 @@
+# Skip-Thought Vectors
+
+This is a TensorFlow implementation of the model described in:
+
+Jamie Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
+Antonio Torralba, Raquel Urtasun, Sanja Fidler.
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf).
+*In NIPS, 2015.*
+
+
+## Contact
+***Code author:*** Chris Shallue
+
+***Pull requests and issues:*** @cshallue
+
+## Contents
+* [Model Overview](#model-overview)
+* [Getting Started](#getting-started)
+    * [Install Required Packages](#install-required-packages)
+    * [Download Pretrained Models (Optional)](#download-pretrained-models-optional)
+* [Training a Model](#training-a-model)
+    * [Prepare the Training Data](#prepare-the-training-data)
+    * [Run the Training Script](#run-the-training-script)
+    * [Track Training Progress](#track-training-progress)
+* [Expanding the Vocabulary](#expanding-the-vocabulary)
+    * [Overview](#overview)
+    * [Preparation](#preparation)
+    * [Run the Vocabulary Expansion Script](#run-the-vocabulary-expansion-script)
+* [Evaluating a Model](#evaluating-a-model)
+    * [Overview](#overview-1)
+    * [Preparation](#preparation-1)
+    * [Run the Evaluation Tasks](#run-the-evaluation-tasks)
+* [Encoding Sentences](#encoding-sentences)
+
+## Model overview
+
+The *Skip-Thoughts* model is a sentence encoder. It learns to encode input
+sentences into a fixed-dimensional vector representation that is useful for many
+tasks, for example to detect paraphrases or to classify whether a product review
+is positive or negative. See the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper for details of the model architecture and more example applications.
+
+A trained *Skip-Thoughts* model will encode similar sentences nearby each other
+in the embedding vector space. The following examples show the nearest neighbor by
+cosine similarity of some sentences from the
+[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/).
+
+
+| Input sentence | Nearest Neighbor |
+|----------------|------------------|
+| Simplistic, silly and tedious. | Trite, banal, cliched, mostly inoffensive. |
+| Not so much farcical as sour. | Not only unfunny, but downright repellent. |
+| A sensitive and astute first feature by Anne-Sophie Birot. | Absorbing character study by André Turpin . |
+| An enthralling, entertaining feature. |  A slick, engrossing melodrama. |
+
+## Getting Started
+
+### Install Required Packages
+First ensure that you have installed the following required packages:
+
+* **Bazel** ([instructions](http://bazel.build/docs/install.html))
+* **TensorFlow** ([instructions](https://www.tensorflow.org/install/))
+* **NumPy** ([instructions](http://www.scipy.org/install.html))
+* **scikit-learn** ([instructions](http://scikit-learn.org/stable/install.html))
+* **Natural Language Toolkit (NLTK)**
+    * First install NLTK ([instructions](http://www.nltk.org/install.html))
+    * Then install the NLTK data ([instructions](http://www.nltk.org/data.html))
+* **gensim** ([instructions](https://radimrehurek.com/gensim/install.html))
+    * Only required if you will be expanding your vocabulary with the [word2vec](https://code.google.com/archive/p/word2vec/) model.
+
+
+### Download Pretrained Models (Optional)
+
+You can download model checkpoints pretrained on the
+[BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset in the following
+configurations:
+
+* Unidirectional RNN encoder ("uni-skip" in the paper)
+* Bidirectional RNN encoder ("bi-skip" in the paper)
+
+```shell
+# Directory to download the pretrained models to.
+PRETRAINED_MODELS_DIR="${HOME}/skip_thoughts/pretrained/"
+
+mkdir -p ${PRETRAINED_MODELS_DIR}
+cd ${PRETRAINED_MODELS_DIR}
+
+# Download and extract the unidirectional model.
+wget "http://download.tensorflow.org/models/skip_thoughts_uni_2017_02_02.tar.gz"
+tar -xvf skip_thoughts_uni_2017_02_02.tar.gz
+rm skip_thoughts_uni_2017_02_02.tar.gz
+
+# Download and extract the bidirectional model.
+wget "http://download.tensorflow.org/models/skip_thoughts_bi_2017_02_16.tar.gz"
+tar -xvf skip_thoughts_bi_2017_02_16.tar.gz
+rm skip_thoughts_bi_2017_02_16.tar.gz
+```
+
+You can now skip to the sections [Evaluating a Model](#evaluating-a-model) and
+[Encoding Sentences](#encoding-sentences).
+
+
+## Training a Model
+
+### Prepare the Training Data
+
+To train a model you will need to provide training data in TFRecord format. The
+TFRecord format consists of a set of sharded files containing serialized
+`tf.Example` protocol buffers. Each `tf.Example` proto contains three
+sentences:
+
+  * `encode`: The sentence to encode.
+  * `decode_pre`: The sentence preceding `encode` in the original text.
+  * `decode_post`: The sentence following `encode` in the original text.
+
+Each sentence is a list of words. During preprocessing, a dictionary is created
+that assigns each word in the vocabulary to an integer-valued id. Each sentence
+is encoded as a list of integer word ids in the `tf.Example` protos.
+
+We have provided a script to preprocess any set of text-files into this format.
+You may wish to use the [BookCorpus](http://yknzhu.wixsite.com/mbweb) dataset.
+Note that the preprocessing script may take **12 hours** or more to complete
+on this large dataset.
+
+```shell
+# Comma-separated list of globs matching the input files. The format of
+# the input files is assumed to be a list of newline-separated sentences, where
+# each sentence is already tokenized.
+INPUT_FILES="${HOME}/skip_thoughts/bookcorpus/*.txt"
+
+# Location to save the preprocessed training and validation data.
+DATA_DIR="${HOME}/skip_thoughts/data"
+
+# Build the preprocessing script.
+bazel build -c opt skip_thoughts/data/preprocess_dataset
+
+# Run the preprocessing script.
+bazel-bin/skip_thoughts/data/preprocess_dataset \
+  --input_files=${INPUT_FILES} \
+  --output_dir=${DATA_DIR}
+```
+
+When the script finishes you will find 100 training files and 1 validation file
+in `DATA_DIR`. The files will match the patterns `train-?????-of-00100` and
+`validation-00000-of-00001` respectively.
+
+The script will also produce a file named `vocab.txt`. The format of this file
+is a list of newline-separated words where the word id is the corresponding 0-
+based line index. Words are sorted by descending order of frequency in the input
+data. Only the top 20,000 words are assigned unique ids; all other words are
+assigned the "unknown id" of 1 in the processed data.
+
+### Run the Training Script
+
+Execute the following commands to start the training script. By default it will
+run for 500k steps (around 9 days on a GeForce GTX 1080 GPU).
+
+```shell
+# Directory containing the preprocessed data.
+DATA_DIR="${HOME}/skip_thoughts/data"
+
+# Directory to save the model.
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Build the model.
+bazel build -c opt skip_thoughts/...
+
+# Run the training script.
+bazel-bin/skip_thoughts/train \
+  --input_file_pattern="${DATA_DIR}/train-?????-of-00100" \
+  --train_dir="${MODEL_DIR}/train"
+```
+
+### Track Training Progress
+
+Optionally, you can run the `track_perplexity` script in a separate process.
+This will log per-word perplexity on the validation set which allows training
+progress to be monitored on
+[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
+
+Note that you may run out of memory if you run this script on the same GPU
+as the training script. You can set the environment variable
+`CUDA_VISIBLE_DEVICES=""` to force the script to run on CPU. If it runs too
+slowly on CPU, you can decrease the value of `--num_eval_examples`.
+
+```shell
+DATA_DIR="${HOME}/skip_thoughts/data"
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run the evaluation script. This will run in a loop, periodically loading the
+# latest model checkpoint file and computing evaluation metrics.
+bazel-bin/skip_thoughts/track_perplexity \
+  --input_file_pattern="${DATA_DIR}/validation-?????-of-00001" \
+  --checkpoint_dir="${MODEL_DIR}/train" \
+  --eval_dir="${MODEL_DIR}/val" \
+  --num_eval_examples=50000
+```
+
+If you started the `track_perplexity` script, run a
+[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
+server in a separate process for real-time monitoring of training summaries and
+validation perplexity.
+
+```shell
+MODEL_DIR="${HOME}/skip_thoughts/model"
+
+# Run a TensorBoard server.
+tensorboard --logdir="${MODEL_DIR}"
+```
+
+## Expanding the Vocabulary
+
+### Overview
+
+The vocabulary generated by the preprocessing script contains only 20,000 words
+which is insufficient for many tasks. For example, a sentence from Wikipedia
+might contain nouns that do not appear in this vocabulary.
+
+A solution to this problem described in the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper is to learn a mapping that transfers word representations from one model to
+another. This idea is based on the "Translation Matrix" method from the paper
+[Exploiting Similarities Among Languages for Machine Translation](https://arxiv.org/abs/1309.4168).
+
+
+Specifically, we will load the word embeddings from a trained *Skip-Thoughts*
+model and from a trained [word2vec model](https://arxiv.org/pdf/1301.3781.pdf)
+(which has a much larger vocabulary). We will train a linear regression model
+without regularization to learn a linear mapping from the word2vec embedding
+space to the *Skip-Thoughts* embedding space. We will then apply the linear
+model to all words in the word2vec vocabulary, yielding vectors in the *Skip-
+Thoughts* word embedding space for the union of the two vocabularies.
+
+The linear regression task is to learn a parameter matrix *W* to minimize
+*|| X - Y \* W ||<sup>2</sup>*, where *X* is a matrix of *Skip-Thoughts*
+embeddings of shape `[num_words, dim1]`, *Y* is a matrix of word2vec embeddings
+of shape `[num_words, dim2]`, and *W* is a matrix of shape `[dim2, dim1]`.
+
+### Preparation
+
+First you will need to download and unpack a pretrained
+[word2vec model](https://arxiv.org/pdf/1301.3781.pdf) from
+[this website](https://code.google.com/archive/p/word2vec/)
+([direct download link](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing)).
+This model was trained on the Google News dataset (about 100 billion words).
+
+
+Also ensure that you have already [installed gensim](https://radimrehurek.com/gensim/install.html).
+
+### Run the Vocabulary Expansion Script
+
+```shell
+# Path to checkpoint file or a directory containing checkpoint files (the script
+# will select the most recent).
+CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
+
+# Vocabulary file generated by the preprocessing script.
+SKIP_THOUGHTS_VOCAB="${HOME}/skip_thoughts/data/vocab.txt"
+
+# Path to downloaded word2vec model.
+WORD2VEC_MODEL="${HOME}/skip_thoughts/googlenews/GoogleNews-vectors-negative300.bin"
+
+# Output directory.
+EXP_VOCAB_DIR="${HOME}/skip_thoughts/exp_vocab"
+
+# Build the vocabulary expansion script.
+bazel build -c opt skip_thoughts/vocabulary_expansion
+
+# Run the vocabulary expansion script.
+bazel-bin/skip_thoughts/vocabulary_expansion \
+  --skip_thoughts_model=${CHECKPOINT_PATH} \
+  --skip_thoughts_vocab=${SKIP_THOUGHTS_VOCAB} \
+  --word2vec_model=${WORD2VEC_MODEL} \
+  --output_dir=${EXP_VOCAB_DIR}
+```
+
+## Evaluating a Model
+
+### Overview
+
+The model can be evaluated using the benchmark tasks described in the
+[Skip-Thought Vectors](https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf)
+paper. The following tasks are supported (refer to the paper for full details):
+
+ * **SICK** semantic relatedness task.
+ * **MSRP** (Microsoft Research Paraphrase Corpus) paraphrase detection task.
+ * Binary classification tasks:
+   * **MR** movie review sentiment task.
+   * **CR** customer product review task.
+   * **SUBJ** subjectivity/objectivity task.
+   * **MPQA** opinion polarity task.
+   * **TREC** question-type classification task.
+
+### Preparation
+
+You will need to clone or download the
+[skip-thoughts GitHub repository](https://github.com/ryankiros/skip-thoughts) by
+[ryankiros](https://github.com/ryankiros) (the first author of the Skip-Thoughts
+paper):
+
+```shell
+# Folder to clone the repository to.
+ST_KIROS_DIR="${HOME}/skip_thoughts/skipthoughts_kiros"
+
+# Clone the repository.
+git clone git@github.com:ryankiros/skip-thoughts.git "${ST_KIROS_DIR}/skipthoughts"
+
+# Make the package importable.
+export PYTHONPATH="${ST_KIROS_DIR}/:${PYTHONPATH}"
+```
+
+You will also need to download the data needed for each evaluation task. See the
+instructions [here](https://github.com/ryankiros/skip-thoughts).
+
+For example, the CR (customer review) dataset is found [here](http://nlp.stanford.edu/~sidaw/home/projects:nbsvm). For this task we want the
+files `custrev.pos` and `custrev.neg`.
+
+### Run the Evaluation Tasks
+
+In the following example we will evaluate a unidirectional model ("uni-skip" in
+the paper) on the CR task. To use a bidirectional model ("bi-skip" in the
+paper),  simply pass the flags `--bi_vocab_file`, `--bi_embeddings_file` and
+`--bi_checkpoint_path` instead. To use the "combine-skip" model described in the
+paper you will need to pass both the unidirectional and bidirectional flags.
+
+```shell
+# Path to checkpoint file or a directory containing checkpoint files (the script
+# will select the most recent).
+CHECKPOINT_PATH="${HOME}/skip_thoughts/model/train"
+
+# Vocabulary file generated by the vocabulary expansion script.
+VOCAB_FILE="${HOME}/skip_thoughts/exp_vocab/vocab.txt"
+
+# Embeddings file generated by the vocabulary expansion script.
+EMBEDDINGS_FILE="${HOME}/skip_thoughts/exp_vocab/embeddings.npy"
+
+# Directory containing files custrev.pos and custrev.neg.
+EVAL_DATA_DIR="${HOME}/skip_thoughts/eval_data"
+
+# Build the evaluation script.
+bazel build -c opt skip_thoughts/evaluate
+
+# Run the evaluation script.
+bazel-bin/skip_thoughts/evaluate \
+  --eval_task=CR \
+  --data_dir=${EVAL_DATA_DIR} \
+  --uni_vocab_file=${VOCAB_FILE} \
+  --uni_embeddings_file=${EMBEDDINGS_FILE} \
+  --uni_checkpoint_path=${CHECKPOINT_PATH}
+```
+
+Output:
+
+```python
+[0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
+ 0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
+ 0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
+ 0.82758620689655171]
+```
+
+The output is a list of accuracies of 10 cross-validation classification models.
+To get a single number, simply take the average:
+
+```python
+ipython  # Launch iPython.
+
+In [0]:
+import numpy as np
+np.mean([0.82539682539682535, 0.84084880636604775, 0.83023872679045096,
+         0.86206896551724133, 0.83554376657824936, 0.85676392572944293,
+         0.84084880636604775, 0.83023872679045096, 0.85145888594164454,
+         0.82758620689655171])
+
+Out [0]: 0.84009936423729525
+```
+
+## Encoding Sentences
+
+In this example we will encode data from the
+[movie review dataset](https://www.cs.cornell.edu/people/pabo/movie-review-data/)
+(specifically the [sentence polarity dataset v1.0](https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz)).
+
+```python
+ipython  # Launch iPython.
+
+In [0]:
+
+# Imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import os.path
+import scipy.spatial.distance as sd
+from skip_thoughts import configuration
+from skip_thoughts import encoder_manager
+
+In [1]:
+# Set paths to the model.
+VOCAB_FILE = "/path/to/vocab.txt"
+EMBEDDING_MATRIX_FILE = "/path/to/embeddings.npy"
+CHECKPOINT_PATH = "/path/to/model.ckpt-9999"
+# The following directory should contain files rt-polarity.neg and
+# rt-polarity.pos.
+MR_DATA_DIR = "/dir/containing/mr/data"
+
+In [2]:
+# Set up the encoder. Here we are using a single unidirectional model.
+# To use a bidirectional model as well, call load_model() again with
+# configuration.model_config(bidirectional_encoder=True) and paths to the
+# bidirectional model's files. The encoder will use the concatenation of
+# all loaded models.
+encoder = encoder_manager.EncoderManager()
+encoder.load_model(configuration.model_config(),
+                   vocabulary_file=VOCAB_FILE,
+                   embedding_matrix_file=EMBEDDING_MATRIX_FILE,
+                   checkpoint_path=CHECKPOINT_PATH)
+
+In [3]:
+# Load the movie review dataset.
+data = []
+with open(os.path.join(MR_DATA_DIR, 'rt-polarity.neg'), 'rb') as f:
+  data.extend([line.decode('latin-1').strip() for line in f])
+with open(os.path.join(MR_DATA_DIR, 'rt-polarity.pos'), 'rb') as f:
+  data.extend([line.decode('latin-1').strip() for line in f])
+
+In [4]:
+# Generate Skip-Thought Vectors for each sentence in the dataset.
+encodings = encoder.encode(data)
+
+In [5]:
+# Define a helper function to generate nearest neighbors.
+def get_nn(ind, num=10):
+  encoding = encodings[ind]
+  scores = sd.cdist([encoding], encodings, "cosine")[0]
+  sorted_ids = np.argsort(scores)
+  print("Sentence:")
+  print("", data[ind])
+  print("\nNearest neighbors:")
+  for i in range(1, num + 1):
+    print(" %d. %s (%.3f)" %
+          (i, data[sorted_ids[i]], scores[sorted_ids[i]]))
+
+In [6]:
+# Compute nearest neighbors of the first sentence in the dataset.
+get_nn(0)
+```
+
+Output:
+
+```
+Sentence:
+ simplistic , silly and tedious .
+
+Nearest neighbors:
+ 1. trite , banal , cliched , mostly inoffensive . (0.247)
+ 2. banal and predictable . (0.253)
+ 3. witless , pointless , tasteless and idiotic . (0.272)
+ 4. loud , silly , stupid and pointless . (0.295)
+ 5. grating and tedious . (0.299)
+ 6. idiotic and ugly . (0.330)
+ 7. black-and-white and unrealistic . (0.335)
+ 8. hopelessly inane , humorless and under-inspired . (0.335)
+ 9. shallow , noisy and pretentious . (0.340)
+ 10. . . . unlikable , uninteresting , unfunny , and completely , utterly inept . (0.346)
+```

+ 0 - 0
skip_thoughts/WORKSPACE


+ 94 - 0
skip_thoughts/skip_thoughts/BUILD

@@ -0,0 +1,94 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//skip_thoughts/...",
+    ],
+)
+
+py_library(
+    name = "configuration",
+    srcs = ["configuration.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "skip_thoughts_model",
+    srcs = ["skip_thoughts_model.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//skip_thoughts/ops:gru_cell",
+        "//skip_thoughts/ops:input_ops",
+    ],
+)
+
+py_test(
+    name = "skip_thoughts_model_test",
+    size = "large",
+    srcs = ["skip_thoughts_model_test.py"],
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "train",
+    srcs = ["train.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "track_perplexity",
+    srcs = ["track_perplexity.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":skip_thoughts_model",
+    ],
+)
+
+py_binary(
+    name = "vocabulary_expansion",
+    srcs = ["vocabulary_expansion.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "skip_thoughts_encoder",
+    srcs = ["skip_thoughts_encoder.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":skip_thoughts_model",
+        "//skip_thoughts/data:special_words",
+    ],
+)
+
+py_library(
+    name = "encoder_manager",
+    srcs = ["encoder_manager.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":skip_thoughts_encoder",
+    ],
+)
+
+py_binary(
+    name = "evaluate",
+    srcs = ["evaluate.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":encoder_manager",
+        "//skip_thoughts:configuration",
+    ],
+)
+

+ 0 - 0
skip_thoughts/skip_thoughts/__init__.py


+ 110 - 0
skip_thoughts/skip_thoughts/configuration.py

@@ -0,0 +1,110 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Default configuration for model architecture and training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class _HParams(object):
+  """Wrapper for configuration parameters."""
+  pass
+
+
+def model_config(input_file_pattern=None,
+                 input_queue_capacity=640000,
+                 num_input_reader_threads=1,
+                 shuffle_input_data=True,
+                 uniform_init_scale=0.1,
+                 vocab_size=20000,
+                 batch_size=128,
+                 word_embedding_dim=620,
+                 bidirectional_encoder=False,
+                 encoder_dim=2400):
+  """Creates a model configuration object.
+
+  Args:
+    input_file_pattern: File pattern of sharded TFRecord files containing
+      tf.Example protobufs.
+    input_queue_capacity: Number of examples to keep in the input queue.
+    num_input_reader_threads: Number of threads for prefetching input
+      tf.Examples.
+    shuffle_input_data: Whether to shuffle the input data.
+    uniform_init_scale: Scale of random uniform initializer.
+    vocab_size: Number of unique words in the vocab.
+    batch_size: Batch size (training and evaluation only).
+    word_embedding_dim: Word embedding dimension.
+    bidirectional_encoder: Whether to use a bidirectional or unidirectional
+      encoder RNN.
+    encoder_dim: Number of output dimensions of the sentence encoder.
+
+  Returns:
+    An object containing model configuration parameters.
+  """
+  config = _HParams()
+  config.input_file_pattern = input_file_pattern
+  config.input_queue_capacity = input_queue_capacity
+  config.num_input_reader_threads = num_input_reader_threads
+  config.shuffle_input_data = shuffle_input_data
+  config.uniform_init_scale = uniform_init_scale
+  config.vocab_size = vocab_size
+  config.batch_size = batch_size
+  config.word_embedding_dim = word_embedding_dim
+  config.bidirectional_encoder = bidirectional_encoder
+  config.encoder_dim = encoder_dim
+  return config
+
+
+def training_config(learning_rate=0.0008,
+                    learning_rate_decay_factor=0.5,
+                    learning_rate_decay_steps=400000,
+                    number_of_steps=500000,
+                    clip_gradient_norm=5.0,
+                    save_model_secs=600,
+                    save_summaries_secs=600):
+  """Creates a training configuration object.
+
+  Args:
+    learning_rate: Initial learning rate.
+    learning_rate_decay_factor: If > 0, the learning rate decay factor.
+    learning_rate_decay_steps: The number of steps before the learning rate
+      decays by learning_rate_decay_factor.
+    number_of_steps: The total number of training steps to run. Passing None
+      will cause the training script to run indefinitely.
+    clip_gradient_norm: If not None, then clip gradients to this value.
+    save_model_secs: How often (in seconds) to save model checkpoints.
+    save_summaries_secs: How often (in seconds) to save model summaries.
+
+  Returns:
+    An object containing training configuration parameters.
+
+  Raises:
+    ValueError: If learning_rate_decay_factor is set and
+      learning_rate_decay_steps is unset.
+  """
+  if learning_rate_decay_factor and not learning_rate_decay_steps:
+    raise ValueError(
+        "learning_rate_decay_factor requires learning_rate_decay_steps.")
+
+  config = _HParams()
+  config.learning_rate = learning_rate
+  config.learning_rate_decay_factor = learning_rate_decay_factor
+  config.learning_rate_decay_steps = learning_rate_decay_steps
+  config.number_of_steps = number_of_steps
+  config.clip_gradient_norm = clip_gradient_norm
+  config.save_model_secs = save_model_secs
+  config.save_summaries_secs = save_summaries_secs
+  return config
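A brief usage sketch of the two helpers above (the path and overrides are illustrative; the attribute names come from the functions defined in this file):

```python
from skip_thoughts import configuration

# Override a couple of defaults; everything else keeps the documented values.
model_cfg = configuration.model_config(
    input_file_pattern="/tmp/skip_thoughts/data/train-?????-of-00100",
    bidirectional_encoder=False)               # "uni-skip" style encoder
training_cfg = configuration.training_config(number_of_steps=500000)

print(model_cfg.encoder_dim)       # 2400, the default sentence-vector size
print(training_cfg.learning_rate)  # 0.0008
```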

+ 23 - 0
skip_thoughts/skip_thoughts/data/BUILD

@@ -0,0 +1,23 @@
+package(default_visibility = ["//skip_thoughts:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "special_words",
+    srcs = ["special_words.py"],
+    srcs_version = "PY2AND3",
+    deps = [],
+)
+
+py_binary(
+    name = "preprocess_dataset",
+    srcs = [
+        "preprocess_dataset.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":special_words",
+    ],
+)

+ 0 - 0
skip_thoughts/skip_thoughts/data/__init__.py


+ 301 - 0
skip_thoughts/skip_thoughts/data/preprocess_dataset.py

@@ -0,0 +1,301 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts a set of text files to TFRecord format with Example protos.
+
+Each Example proto in the output contains the following fields:
+
+  decode_pre: list of int64 ids corresponding to the "previous" sentence.
+  encode: list of int64 ids corresponding to the "current" sentence.
+  decode_post: list of int64 ids corresponding to the "post" sentence.
+
+In addition, the following files are generated:
+
+  vocab.txt: List of "<word> <id>" pairs, where <id> is the integer
+             encoding of <word> in the Example protos.
+  word_counts.txt: List of "<word> <count>" pairs, where <count> is the number
+                   of occurrences of <word> in the input files.
+
+The vocabulary of word ids is constructed from the top --num_words by word
+count. All other words get the <unk> word id.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts.data import special_words
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_files", None,
+                       "Comma-separated list of globs matching the input "
+                       "files. The format of the input files is assumed to be "
+                       "a list of newline-separated sentences, where each "
+                       "sentence is already tokenized.")
+
+tf.flags.DEFINE_string("vocab_file", "",
+                       "(Optional) existing vocab file. Otherwise, a new vocab "
+                       "file is created and written to the output directory. "
+                       "The file format is a list of newline-separated words, "
+                       "where the word id is the corresponding 0-based index "
+                       "in the file.")
+
+tf.flags.DEFINE_string("output_dir", None, "Output directory.")
+
+tf.flags.DEFINE_integer("train_output_shards", 100,
+                        "Number of output shards for the training set.")
+
+tf.flags.DEFINE_integer("validation_output_shards", 1,
+                        "Number of output shards for the validation set.")
+
+tf.flags.DEFINE_integer("num_validation_sentences", 50000,
+                        "Number of sentences to include in the validation set.")
+
+tf.flags.DEFINE_integer("num_words", 20000,
+                        "Number of words to include in the output.")
+
+tf.flags.DEFINE_integer("max_sentences", 0,
+                        "If > 0, the maximum number of sentences to output.")
+
+tf.flags.DEFINE_integer("max_sentence_length", 30,
+                        "If > 0, exclude sentences whose encode, decode_pre OR "
+                        "decode_post sentence exceeds this length.")
+
+tf.flags.DEFINE_boolean("add_eos", True,
+                        "Whether to add end-of-sentence ids to the output.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _build_vocabulary(input_files):
+  """Loads or builds the model vocabulary.
+
+  Args:
+    input_files: List of pre-tokenized input .txt files.
+
+  Returns:
+    vocab: A dictionary of word to id.
+  """
+  if FLAGS.vocab_file:
+    tf.logging.info("Loading existing vocab file.")
+    vocab = collections.OrderedDict()
+    with tf.gfile.GFile(FLAGS.vocab_file, mode="r") as f:
+      for i, line in enumerate(f):
+        word = line.decode("utf-8").strip()
+        assert word not in vocab, "Attempting to add word twice: %s" % word
+        vocab[word] = i
+    tf.logging.info("Read vocab of size %d from %s",
+                    len(vocab), FLAGS.vocab_file)
+    return vocab
+
+  tf.logging.info("Creating vocabulary.")
+  num = 0
+  wordcount = collections.Counter()
+  for input_file in input_files:
+    tf.logging.info("Processing file: %s", input_file)
+    for sentence in tf.gfile.FastGFile(input_file):
+      wordcount.update(sentence.split())
+
+      num += 1
+      if num % 1000000 == 0:
+        tf.logging.info("Processed %d sentences", num)
+
+  tf.logging.info("Processed %d sentences total", num)
+
+  words = wordcount.keys()
+  freqs = wordcount.values()
+  sorted_indices = np.argsort(freqs)[::-1]
+
+  vocab = collections.OrderedDict()
+  vocab[special_words.EOS] = special_words.EOS_ID
+  vocab[special_words.UNK] = special_words.UNK_ID
+  for w_id, w_index in enumerate(sorted_indices[0:FLAGS.num_words - 2]):
+    vocab[words[w_index]] = w_id + 2  # 0: EOS, 1: UNK.
+
+  tf.logging.info("Created vocab with %d words", len(vocab))
+
+  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
+  with tf.gfile.FastGFile(vocab_file, "w") as f:
+    f.write("\n".join(vocab.keys()))
+  tf.logging.info("Wrote vocab file to %s", vocab_file)
+
+  word_counts_file = os.path.join(FLAGS.output_dir, "word_counts.txt")
+  with tf.gfile.FastGFile(word_counts_file, "w") as f:
+    for i in sorted_indices:
+      f.write("%s %d\n" % (words[i], freqs[i]))
+  tf.logging.info("Wrote word counts file to %s", word_counts_file)
+
+  return vocab
+
+
+def _int64_feature(value):
+  """Helper for creating an Int64 Feature."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(
+      value=[int(v) for v in value]))
+
+
+def _sentence_to_ids(sentence, vocab):
+  """Helper for converting a sentence (list of words) to a list of ids."""
+  ids = [vocab.get(w, special_words.UNK_ID) for w in sentence]
+  if FLAGS.add_eos:
+    ids.append(special_words.EOS_ID)
+  return ids
+
+
+def _create_serialized_example(predecessor, current, successor, vocab):
+  """Helper for creating a serialized Example proto."""
+  example = tf.train.Example(features=tf.train.Features(feature={
+      "decode_pre": _int64_feature(_sentence_to_ids(predecessor, vocab)),
+      "encode": _int64_feature(_sentence_to_ids(current, vocab)),
+      "decode_post": _int64_feature(_sentence_to_ids(successor, vocab)),
+  }))
+
+  return example.SerializeToString()
+
+
+def _process_input_file(filename, vocab, stats):
+  """Processes the sentences in an input file.
+
+  Args:
+    filename: Path to a pre-tokenized input .txt file.
+    vocab: A dictionary of word to id.
+    stats: A Counter object for statistics.
+
+  Returns:
+    processed: A list of serialized Example protos
+  """
+  tf.logging.info("Processing input file: %s", filename)
+  processed = []
+
+  predecessor = None  # Predecessor sentence (list of words).
+  current = None  # Current sentence (list of words).
+  successor = None  # Successor sentence (list of words).
+
+  for successor_str in tf.gfile.FastGFile(filename):
+    stats.update(["sentences_seen"])
+    successor = successor_str.split()
+
+    # The first 2 sentences per file will be skipped.
+    if predecessor and current and successor:
+      stats.update(["sentences_considered"])
+
+      # Note that we are going to insert <EOS> later, so we only allow
+      # sentences with strictly less than max_sentence_length to pass.
+      if FLAGS.max_sentence_length and (
+          len(predecessor) >= FLAGS.max_sentence_length or len(current) >=
+          FLAGS.max_sentence_length or len(successor) >=
+          FLAGS.max_sentence_length):
+        stats.update(["sentences_too_long"])
+      else:
+        serialized = _create_serialized_example(predecessor, current, successor,
+                                                vocab)
+        processed.append(serialized)
+        stats.update(["sentences_output"])
+
+    predecessor = current
+    current = successor
+
+    sentences_seen = stats["sentences_seen"]
+    sentences_output = stats["sentences_output"]
+    if sentences_seen and sentences_seen % 100000 == 0:
+      tf.logging.info("Processed %d sentences (%d output)", sentences_seen,
+                      sentences_output)
+    if FLAGS.max_sentences and sentences_output >= FLAGS.max_sentences:
+      break
+
+  tf.logging.info("Completed processing file %s", filename)
+  return processed
+
+
+def _write_shard(filename, dataset, indices):
+  """Writes a TFRecord shard."""
+  with tf.python_io.TFRecordWriter(filename) as writer:
+    for j in indices:
+      writer.write(dataset[j])
+
+
+def _write_dataset(name, dataset, indices, num_shards):
+  """Writes a sharded TFRecord dataset.
+
+  Args:
+    name: Name of the dataset (e.g. "train").
+    dataset: List of serialized Example protos.
+    indices: List of indices of 'dataset' to be written.
+    num_shards: The number of output shards.
+  """
+  tf.logging.info("Writing dataset %s", name)
+  borders = np.int32(np.linspace(0, len(indices), num_shards + 1))
+  for i in range(num_shards):
+    filename = os.path.join(FLAGS.output_dir, "%s-%.5d-of-%.5d" % (name, i,
+                                                                   num_shards))
+    shard_indices = indices[borders[i]:borders[i + 1]]
+    _write_shard(filename, dataset, shard_indices)
+    tf.logging.info("Wrote dataset indices [%d, %d) to output shard %s",
+                    borders[i], borders[i + 1], filename)
+  tf.logging.info("Finished writing %d sentences in dataset %s.",
+                  len(indices), name)
+
+
+def main(unused_argv):
+  if not FLAGS.input_files:
+    raise ValueError("--input_files is required.")
+  if not FLAGS.output_dir:
+    raise ValueError("--output_dir is required.")
+
+  if not tf.gfile.IsDirectory(FLAGS.output_dir):
+    tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  input_files = []
+  for pattern in FLAGS.input_files.split(","):
+    match = tf.gfile.Glob(pattern)
+    if not match:
+      raise ValueError("Found no files matching %s" % pattern)
+    input_files.extend(match)
+  tf.logging.info("Found %d input files.", len(input_files))
+
+  vocab = _build_vocabulary(input_files)
+
+  tf.logging.info("Generating dataset.")
+  stats = collections.Counter()
+  dataset = []
+  for filename in input_files:
+    dataset.extend(_process_input_file(filename, vocab, stats))
+    if FLAGS.max_sentences and stats["sentences_output"] >= FLAGS.max_sentences:
+      break
+
+  tf.logging.info("Generated dataset with %d sentences.", len(dataset))
+  for k, v in stats.items():
+    tf.logging.info("%s: %d", k, v)
+
+  tf.logging.info("Shuffling dataset.")
+  np.random.seed(123)
+  shuffled_indices = np.random.permutation(len(dataset))
+  val_indices = shuffled_indices[:FLAGS.num_validation_sentences]
+  train_indices = shuffled_indices[FLAGS.num_validation_sentences:]
+
+  _write_dataset("train", dataset, train_indices, FLAGS.train_output_shards)
+  _write_dataset("validation", dataset, val_indices,
+                 FLAGS.validation_output_shards)
+
+
+if __name__ == "__main__":
+  tf.app.run()
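
A quick illustration of the sliding-window logic in _process_input_file above: each file is turned into (predecessor, current, successor) triples, so the first two sentences of every file never produce an output example. The sketch below is standalone and hypothetical, not part of the change.

    # Standalone sketch of the triple windowing; names are illustrative only.
    def sentence_triples(sentences):
      predecessor, current = None, None
      for successor in sentences:
        if predecessor is not None and current is not None:
          yield predecessor, current, successor
        predecessor, current = current, successor

    print(list(sentence_triples(["s1", "s2", "s3", "s4"])))
    # [('s1', 's2', 's3'), ('s2', 's3', 's4')]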

+ 27 - 0
skip_thoughts/skip_thoughts/data/special_words.py

@@ -0,0 +1,27 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special word constants.
+
+NOTE: The ids of the EOS and UNK constants should not be modified. It is assumed
+that these always occupy the first two ids.
+"""
+
+# End of sentence.
+EOS = "<eos>"
+EOS_ID = 0
+
+# Unknown.
+UNK = "<unk>"
+UNK_ID = 1

+ 134 - 0
skip_thoughts/skip_thoughts/encoder_manager.py

@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Manager class for loading and encoding with multiple skip-thoughts models.
+
+If multiple models are loaded at once then the encode() function returns the
+concatenation of the outputs of each model.
+
+Example usage:
+  manager = EncoderManager()
+  manager.load_model(model_config_1, vocabulary_file_1, embedding_matrix_file_1,
+                     checkpoint_path_1)
+  manager.load_model(model_config_2, vocabulary_file_2, embedding_matrix_file_2,
+                     checkpoint_path_2)
+  encodings = manager.encode(data)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import skip_thoughts_encoder
+
+
+class EncoderManager(object):
+  """Manager class for loading and encoding with skip-thoughts models."""
+
+  def __init__(self):
+    self.encoders = []
+    self.sessions = []
+
+  def load_model(self, model_config, vocabulary_file, embedding_matrix_file,
+                 checkpoint_path):
+    """Loads a skip-thoughts model.
+
+    Args:
+      model_config: Object containing parameters for building the model.
+      vocabulary_file: Path to vocabulary file containing a list of newline-
+        separated words where the word id is the corresponding 0-based index in
+        the file.
+      embedding_matrix_file: Path to a serialized numpy array of shape
+        [vocab_size, embedding_dim].
+      checkpoint_path: SkipThoughtsModel checkpoint file or a directory
+        containing a checkpoint file.
+    """
+    tf.logging.info("Reading vocabulary from %s", vocabulary_file)
+    with tf.gfile.GFile(vocabulary_file, mode="rb") as f:
+      lines = list(f.readlines())
+    reverse_vocab = [line.decode("utf-8").strip() for line in lines]
+    tf.logging.info("Loaded vocabulary with %d words.", len(reverse_vocab))
+
+    tf.logging.info("Loading embedding matrix from %s", embedding_matrix_file)
+    # Note: tf.gfile.GFile doesn't work here because np.load() calls f.seek()
+    # with 3 arguments.
+    with open(embedding_matrix_file, "rb") as f:
+      embedding_matrix = np.load(f)
+    tf.logging.info("Loaded embedding matrix with shape %s",
+                    embedding_matrix.shape)
+
+    word_embeddings = collections.OrderedDict(
+        zip(reverse_vocab, embedding_matrix))
+
+    g = tf.Graph()
+    with g.as_default():
+      encoder = skip_thoughts_encoder.SkipThoughtsEncoder(word_embeddings)
+      restore_model = encoder.build_graph_from_config(model_config,
+                                                      checkpoint_path)
+
+    sess = tf.Session(graph=g)
+    restore_model(sess)
+
+    self.encoders.append(encoder)
+    self.sessions.append(sess)
+
+  def encode(self,
+             data,
+             use_norm=True,
+             verbose=False,
+             batch_size=128,
+             use_eos=False):
+    """Encodes a sequence of sentences as skip-thought vectors.
+
+    Args:
+      data: A list of input strings.
+      use_norm: If True, normalize output skip-thought vectors to unit L2 norm.
+      verbose: Whether to log every batch.
+      batch_size: Batch size for the RNN encoders.
+      use_eos: If True, append the end-of-sentence word to each input sentence.
+
+    Returns:
+      thought_vectors: A numpy array of shape [len(data), encoding_dim], where
+        encoding_dim is the sum of the loaded models' thought vector sizes.
+
+    Raises:
+      ValueError: If called before calling load_model.
+    """
+    if not self.encoders:
+      raise ValueError(
+          "Must call load_model at least once before calling encode.")
+
+    encoded = []
+    for encoder, sess in zip(self.encoders, self.sessions):
+      encoded.append(
+          np.array(
+              encoder.encode(
+                  sess,
+                  data,
+                  use_norm=use_norm,
+                  verbose=verbose,
+                  batch_size=batch_size,
+                  use_eos=use_eos)))
+
+    return np.concatenate(encoded, axis=1)
+
+  def close(self):
+    """Closes the active TensorFlow Sessions."""
+    for sess in self.sessions:
+      sess.close()
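
For reference, a hypothetical end-to-end usage sketch of EncoderManager. The paths below are placeholders (not files shipped with this change), and the 2400-dimensional output assumes the default unidirectional configuration used in the model tests later in this change.

    from skip_thoughts import configuration
    from skip_thoughts import encoder_manager

    manager = encoder_manager.EncoderManager()
    manager.load_model(configuration.model_config(),
                       vocabulary_file="/tmp/skip_thoughts/vocab.txt",
                       embedding_matrix_file="/tmp/skip_thoughts/embeddings.npy",
                       checkpoint_path="/tmp/skip_thoughts/model")
    vectors = manager.encode(["First sentence.", "Second sentence."])
    print(vectors.shape)  # (2, 2400) with the default unidirectional config.
    manager.close()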

+ 117 - 0
skip_thoughts/skip_thoughts/evaluate.py

@@ -0,0 +1,117 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate a skip-thoughts model.
+
+This script can evaluate a model with a unidirectional encoder ("uni-skip" in
+the paper); or a model with a bidirectional encoder ("bi-skip"); or the
+combination of a model with a unidirectional encoder and a model with a
+bidirectional encoder ("combine-skip").
+
+The uni-skip model (if it exists) is specified by the flags
+--uni_vocab_file, --uni_embeddings_file, --uni_checkpoint_path.
+
+The bi-skip model (if it exists) is specified by the flags
+--bi_vocab_file, --bi_embeddings_file, --bi_checkpoint_path.
+
+The evaluation tasks have different running times. SICK may take 5-10 minutes.
+MSRP, TREC and CR may take 20-60 minutes. SUBJ, MPQA and MR may take 2+ hours.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from skipthoughts import eval_classification
+from skipthoughts import eval_msrp
+from skipthoughts import eval_sick
+from skipthoughts import eval_trec
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import encoder_manager
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("eval_task", "CR",
+                       "Name of the evaluation task to run. Available tasks: "
+                       "MR, CR, SUBJ, MPQA, SICK, MSRP, TREC.")
+
+tf.flags.DEFINE_string("data_dir", None, "Directory containing training data.")
+
+tf.flags.DEFINE_string("uni_vocab_file", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+tf.flags.DEFINE_string("bi_vocab_file", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+
+tf.flags.DEFINE_string("uni_embeddings_file", None,
+                       "Path to serialized numpy array of shape "
+                       "[vocab_size, embedding_dim].")
+tf.flags.DEFINE_string("bi_embeddings_file", None,
+                       "Path to serialized numpy array of shape "
+                       "[vocab_size, embedding_dim].")
+
+tf.flags.DEFINE_string("uni_checkpoint_path", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+tf.flags.DEFINE_string("bi_checkpoint_path", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def main(unused_argv):
+  if not FLAGS.data_dir:
+    raise ValueError("--data_dir is required.")
+
+  encoder = encoder_manager.EncoderManager()
+
+  # Maybe load unidirectional encoder.
+  if FLAGS.uni_checkpoint_path:
+    print("Loading unidirectional model...")
+    uni_config = configuration.model_config()
+    encoder.load_model(uni_config, FLAGS.uni_vocab_file,
+                       FLAGS.uni_embeddings_file, FLAGS.uni_checkpoint_path)
+
+  # Maybe load bidirectional encoder.
+  if FLAGS.bi_checkpoint_path:
+    print("Loading bidirectional model...")
+    bi_config = configuration.model_config(bidirectional_encoder=True)
+    encoder.load_model(bi_config, FLAGS.bi_vocab_file, FLAGS.bi_embeddings_file,
+                       FLAGS.bi_checkpoint_path)
+
+  if FLAGS.eval_task in ["MR", "CR", "SUBJ", "MPQA"]:
+    eval_classification.eval_nested_kfold(
+        encoder, FLAGS.eval_task, FLAGS.data_dir, use_nb=False)
+  elif FLAGS.eval_task == "SICK":
+    eval_sick.evaluate(encoder, evaltest=True, loc=FLAGS.data_dir)
+  elif FLAGS.eval_task == "MSRP":
+    eval_msrp.evaluate(
+        encoder, evalcv=True, evaltest=True, use_feats=True, loc=FLAGS.data_dir)
+  elif FLAGS.eval_task == "TREC":
+    eval_trec.evaluate(encoder, evalcv=True, evaltest=True, loc=FLAGS.data_dir)
+  else:
+    raise ValueError("Unrecognized eval_task: %s" % FLAGS.eval_task)
+
+  encoder.close()
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 17 - 0
skip_thoughts/skip_thoughts/ops/BUILD

@@ -0,0 +1,17 @@
+package(default_visibility = ["//skip_thoughts:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "input_ops",
+    srcs = ["input_ops.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "gru_cell",
+    srcs = ["gru_cell.py"],
+    srcs_version = "PY2AND3",
+)

+ 0 - 0
skip_thoughts/skip_thoughts/ops/__init__.py


+ 134 - 0
skip_thoughts/skip_thoughts/ops/gru_cell.py

@@ -0,0 +1,134 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GRU cell implementation for the skip-thought vectors model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+_layer_norm = tf.contrib.layers.layer_norm
+
+
+class LayerNormGRUCell(tf.contrib.rnn.RNNCell):
+  """GRU cell with layer normalization.
+
+  The layer normalization implementation is based on:
+
+    https://arxiv.org/abs/1607.06450.
+
+  "Layer Normalization"
+  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+  """
+
+  def __init__(self,
+               num_units,
+               w_initializer,
+               u_initializer,
+               b_initializer,
+               activation=tf.nn.tanh):
+    """Initializes the cell.
+
+    Args:
+      num_units: Number of cell units.
+      w_initializer: Initializer for the "W" (input) parameter matrices.
+      u_initializer: Initializer for the "U" (recurrent) parameter matrices.
+      b_initializer: Initializer for the "b" (bias) parameter vectors.
+      activation: Cell activation function.
+    """
+    self._num_units = num_units
+    self._w_initializer = w_initializer
+    self._u_initializer = u_initializer
+    self._b_initializer = b_initializer
+    self._activation = activation
+
+  @property
+  def state_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def _w_h_initializer(self):
+    """Returns an initializer for the "W_h" parameter matrix.
+
+    See equation (23) in the paper. The "W_h" parameter matrix is the
+    concatenation of two parameter submatrices. The matrix returned is
+    [U_z, U_r].
+
+    Returns:
+      A Tensor with shape [num_units, 2 * num_units] as described above.
+    """
+
+    def _initializer(shape, dtype=tf.float32, partition_info=None):
+      num_units = self._num_units
+      assert shape == [num_units, 2 * num_units]
+      u_z = self._u_initializer([num_units, num_units], dtype, partition_info)
+      u_r = self._u_initializer([num_units, num_units], dtype, partition_info)
+      return tf.concat([u_z, u_r], 1)
+
+    return _initializer
+
+  def _w_x_initializer(self, input_dim):
+    """Returns an initializer for the "W_x" parameter matrix.
+
+    See equation (23) in the paper. The "W_x" parameter matrix is the
+    concatenation of two parameter submatrices. The matrix returned is
+    [W_z, W_r].
+
+    Args:
+      input_dim: The dimension of the cell inputs.
+
+    Returns:
+      A Tensor with shape [input_dim, 2 * num_units] as described above.
+    """
+
+    def _initializer(shape, dtype=tf.float32, partition_info=None):
+      num_units = self._num_units
+      assert shape == [input_dim, 2 * num_units]
+      w_z = self._w_initializer([input_dim, num_units], dtype, partition_info)
+      w_r = self._w_initializer([input_dim, num_units], dtype, partition_info)
+      return tf.concat([w_z, w_r], 1)
+
+    return _initializer
+
+  def __call__(self, inputs, state, scope=None):
+    """GRU cell with layer normalization."""
+    input_dim = inputs.get_shape().as_list()[1]
+    num_units = self._num_units
+
+    with tf.variable_scope(scope or "gru_cell"):
+      with tf.variable_scope("gates"):
+        w_h = tf.get_variable(
+            "w_h", [num_units, 2 * num_units],
+            initializer=self._w_h_initializer())
+        w_x = tf.get_variable(
+            "w_x", [input_dim, 2 * num_units],
+            initializer=self._w_x_initializer(input_dim))
+        z_and_r = (_layer_norm(tf.matmul(state, w_h), scope="layer_norm/w_h") +
+                   _layer_norm(tf.matmul(inputs, w_x), scope="layer_norm/w_x"))
+        z, r = tf.split(tf.sigmoid(z_and_r), 2, 1)
+      with tf.variable_scope("candidate"):
+        w = tf.get_variable(
+            "w", [input_dim, num_units], initializer=self._w_initializer)
+        u = tf.get_variable(
+            "u", [num_units, num_units], initializer=self._u_initializer)
+        h_hat = (r * _layer_norm(tf.matmul(state, u), scope="layer_norm/u") +
+                 _layer_norm(tf.matmul(inputs, w), scope="layer_norm/w"))
+      new_h = (1 - z) * state + z * self._activation(h_hat)
+    return new_h, new_h
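
A minimal construction sketch for LayerNormGRUCell, mirroring SkipThoughtsModel._initialize_gru_cell later in this change. The sizes (batch 8, embedding 620, 2400 units) are illustrative, and tf.orthogonal_initializer stands in here for the model's random_orthonormal_initializer.

    import tensorflow as tf
    from skip_thoughts.ops import gru_cell

    cell = gru_cell.LayerNormGRUCell(
        num_units=2400,
        w_initializer=tf.random_uniform_initializer(-0.1, 0.1),
        u_initializer=tf.orthogonal_initializer(),
        b_initializer=tf.constant_initializer(0.0))
    inputs = tf.zeros([8, 620])   # [batch_size, word_embedding_dim]
    state = tf.zeros([8, 2400])   # [batch_size, num_units]
    output, new_state = cell(inputs, state)  # output == new_state for a GRU.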

+ 118 - 0
skip_thoughts/skip_thoughts/ops/input_ops.py

@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Input ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+import tensorflow as tf
+
+# A SentenceBatch is a pair of Tensors:
+#  ids: Batch of input sentences represented as sequences of word ids: an int64
+#    Tensor with shape [batch_size, padded_length].
+#  mask: Boolean mask distinguishing real words (1) from padded words (0): an
+#    int32 Tensor with shape [batch_size, padded_length].
+SentenceBatch = collections.namedtuple("SentenceBatch", ("ids", "mask"))
+
+
+def parse_example_batch(serialized):
+  """Parses a batch of tf.Example protos.
+
+  Args:
+    serialized: A 1-D string Tensor; a batch of serialized tf.Example protos.
+  Returns:
+    encode: A SentenceBatch of encode sentences.
+    decode_pre: A SentenceBatch of "previous" sentences to decode.
+    decode_post: A SentenceBatch of "post" sentences to decode.
+  """
+  features = tf.parse_example(
+      serialized,
+      features={
+          "encode": tf.VarLenFeature(dtype=tf.int64),
+          "decode_pre": tf.VarLenFeature(dtype=tf.int64),
+          "decode_post": tf.VarLenFeature(dtype=tf.int64),
+      })
+
+  def _sparse_to_batch(sparse):
+    ids = tf.sparse_tensor_to_dense(sparse)  # Padding with zeroes.
+    mask = tf.sparse_to_dense(sparse.indices, sparse.dense_shape,
+                              tf.ones_like(sparse.values, dtype=tf.int32))
+    return SentenceBatch(ids=ids, mask=mask)
+
+  output_names = ("encode", "decode_pre", "decode_post")
+  return tuple(_sparse_to_batch(features[x]) for x in output_names)
+
+
+def prefetch_input_data(reader,
+                        file_pattern,
+                        shuffle,
+                        capacity,
+                        num_reader_threads=1):
+  """Prefetches string values from disk into an input queue.
+
+  Args:
+    reader: Instance of tf.ReaderBase.
+    file_pattern: Comma-separated list of file patterns (e.g.
+        "/tmp/train_data-?????-of-00100", where '?' acts as a wildcard that
+        matches any character).
+    shuffle: Boolean; whether to randomly shuffle the input data.
+    capacity: Queue capacity (number of records).
+    num_reader_threads: Number of reader threads feeding into the queue.
+
+  Returns:
+    A Queue containing prefetched string values.
+  """
+  data_files = []
+  for pattern in file_pattern.split(","):
+    data_files.extend(tf.gfile.Glob(pattern))
+  if not data_files:
+    tf.logging.fatal("Found no input files matching %s", file_pattern)
+  else:
+    tf.logging.info("Prefetching values from %d files matching %s",
+                    len(data_files), file_pattern)
+
+  filename_queue = tf.train.string_input_producer(
+      data_files, shuffle=shuffle, capacity=16, name="filename_queue")
+
+  if shuffle:
+    min_after_dequeue = int(0.6 * capacity)
+    values_queue = tf.RandomShuffleQueue(
+        capacity=capacity,
+        min_after_dequeue=min_after_dequeue,
+        dtypes=[tf.string],
+        shapes=[[]],
+        name="random_input_queue")
+  else:
+    values_queue = tf.FIFOQueue(
+        capacity=capacity,
+        dtypes=[tf.string],
+        shapes=[[]],
+        name="fifo_input_queue")
+
+  enqueue_ops = []
+  for _ in range(num_reader_threads):
+    _, value = reader.read(filename_queue)
+    enqueue_ops.append(values_queue.enqueue([value]))
+  tf.train.queue_runner.add_queue_runner(
+      tf.train.queue_runner.QueueRunner(values_queue, enqueue_ops))
+  tf.summary.scalar("queue/%s/fraction_of_%d_full" % (values_queue.name,
+                                                      capacity),
+                    tf.cast(values_queue.size(), tf.float32) * (1.0 / capacity))
+
+  return values_queue
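
A brief sketch of how these two ops compose; this mirrors SkipThoughtsModel.build_inputs later in this change, and the file pattern, capacity and batch size are placeholder values.

    import tensorflow as tf
    from skip_thoughts.ops import input_ops

    reader = tf.TFRecordReader()
    input_queue = input_ops.prefetch_input_data(
        reader,
        file_pattern="/tmp/skip_thoughts/train-?????-of-00100",
        shuffle=True,
        capacity=10000,
        num_reader_threads=1)
    serialized = input_queue.dequeue_many(128)  # [batch_size] string Tensor.
    encode, decode_pre, decode_post = input_ops.parse_example_batch(serialized)
    # encode.ids and encode.mask are [batch_size, padded_length] Tensors.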

+ 258 - 0
skip_thoughts/skip_thoughts/skip_thoughts_encoder.py

@@ -0,0 +1,258 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for encoding text using a trained SkipThoughtsModel.
+
+Example usage:
+  g = tf.Graph()
+  with g.as_default():
+    encoder = SkipThoughtsEncoder(embeddings)
+    restore_fn = encoder.build_graph_from_config(model_config, checkpoint_path)
+
+  with tf.Session(graph=g) as sess:
+    restore_fn(sess)
+    skip_thought_vectors = encoder.encode(sess, data)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+
+import nltk
+import nltk.tokenize
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import skip_thoughts_model
+from skip_thoughts.data import special_words
+
+
+def _pad(seq, target_len):
+  """Pads a sequence of word embeddings up to the target length.
+
+  Args:
+    seq: Sequence of word embeddings.
+    target_len: Desired padded sequence length.
+
+  Returns:
+    embeddings: Input sequence padded with zero embeddings up to the target
+      length.
+    mask: A 0/1 vector with zeros corresponding to padded embeddings.
+
+  Raises:
+    ValueError: If len(seq) is not in the interval (0, target_len].
+  """
+  seq_len = len(seq)
+  if seq_len <= 0 or seq_len > target_len:
+    raise ValueError("Expected 0 < len(seq) <= %d, got %d" % (target_len,
+                                                              seq_len))
+
+  emb_dim = seq[0].shape[0]
+  padded_seq = np.zeros(shape=(target_len, emb_dim), dtype=seq[0].dtype)
+  mask = np.zeros(shape=(target_len,), dtype=np.int8)
+  for i in range(seq_len):
+    padded_seq[i] = seq[i]
+    mask[i] = 1
+  return padded_seq, mask
+
+
+def _batch_and_pad(sequences):
+  """Batches and pads sequences of word embeddings into a 2D array.
+
+  Args:
+    sequences: A list of batch_size sequences of word embeddings.
+
+  Returns:
+    embeddings: A numpy array with shape [batch_size, padded_length, emb_dim].
+    mask: A numpy 0/1 array with shape [batch_size, padded_length] with zeros
+      corresponding to padded elements.
+  """
+  batch_embeddings = []
+  batch_mask = []
+  batch_len = max([len(seq) for seq in sequences])
+  for seq in sequences:
+    embeddings, mask = _pad(seq, batch_len)
+    batch_embeddings.append(embeddings)
+    batch_mask.append(mask)
+  return np.array(batch_embeddings), np.array(batch_mask)
+
+
+class SkipThoughtsEncoder(object):
+  """Skip-thoughts sentence encoder."""
+
+  def __init__(self, embeddings):
+    """Initializes the encoder.
+
+    Args:
+      embeddings: Dictionary of word to embedding vector (1D numpy array).
+    """
+    self._sentence_detector = nltk.data.load("tokenizers/punkt/english.pickle")
+    self._embeddings = embeddings
+
+  def _create_restore_fn(self, checkpoint_path, saver):
+    """Creates a function that restores a model from checkpoint.
+
+    Args:
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+      saver: Saver for restoring variables from the checkpoint file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+
+    Raises:
+      ValueError: If checkpoint_path does not refer to a checkpoint file or a
+        directory containing a checkpoint file.
+    """
+    if tf.gfile.IsDirectory(checkpoint_path):
+      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
+      if not latest_checkpoint:
+        raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
+      checkpoint_path = latest_checkpoint
+
+    def _restore_fn(sess):
+      tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
+      saver.restore(sess, checkpoint_path)
+      tf.logging.info("Successfully loaded checkpoint: %s",
+                      os.path.basename(checkpoint_path))
+
+    return _restore_fn
+
+  def build_graph_from_config(self, model_config, checkpoint_path):
+    """Builds the inference graph from a configuration object.
+
+    Args:
+      model_config: Object containing configuration for building the model.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    tf.logging.info("Building model.")
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="encode")
+    model.build()
+    saver = tf.train.Saver()
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def build_graph_from_proto(self, graph_def_file, saver_def_file,
+                             checkpoint_path):
+    """Builds the inference graph from serialized GraphDef and SaverDef protos.
+
+    Args:
+      graph_def_file: File containing a serialized GraphDef proto.
+      saver_def_file: File containing a serialized SaverDef proto.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    # Load the Graph.
+    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
+    graph_def = tf.GraphDef()
+    with tf.gfile.FastGFile(graph_def_file, "rb") as f:
+      graph_def.ParseFromString(f.read())
+    tf.import_graph_def(graph_def, name="")
+
+    # Load the Saver.
+    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
+    saver_def = tf.train.SaverDef()
+    with tf.gfile.FastGFile(saver_def_file, "rb") as f:
+      saver_def.ParseFromString(f.read())
+    saver = tf.train.Saver(saver_def=saver_def)
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def _tokenize(self, item):
+    """Tokenizes an input string into a list of words."""
+    tokenized = []
+    for s in self._sentence_detector.tokenize(item):
+      tokenized.extend(nltk.tokenize.word_tokenize(s))
+
+    return tokenized
+
+  def _word_to_embedding(self, w):
+    """Returns the embedding of a word."""
+    return self._embeddings.get(w, self._embeddings[special_words.UNK])
+
+  def _preprocess(self, data, use_eos):
+    """Preprocesses text for the encoder.
+
+    Args:
+      data: A list of input strings.
+      use_eos: Whether to append the end-of-sentence word to each sentence.
+
+    Returns:
+      embeddings: A list of word embedding sequences corresponding to the input
+        strings.
+    """
+    preprocessed_data = []
+    for item in data:
+      tokenized = self._tokenize(item)
+      if use_eos:
+        tokenized.append(special_words.EOS)
+      preprocessed_data.append([self._word_to_embedding(w) for w in tokenized])
+    return preprocessed_data
+
+  def encode(self,
+             sess,
+             data,
+             use_norm=True,
+             verbose=True,
+             batch_size=128,
+             use_eos=False):
+    """Encodes a sequence of sentences as skip-thought vectors.
+
+    Args:
+      sess: TensorFlow Session.
+      data: A list of input strings.
+      use_norm: Whether to normalize skip-thought vectors to unit L2 norm.
+      verbose: Whether to log every batch.
+      batch_size: Batch size for the encoder.
+      use_eos: Whether to append the end-of-sentence word to each input
+        sentence.
+
+    Returns:
+      thought_vectors: A list of numpy arrays corresponding to the skip-thought
+        encodings of sentences in 'data'.
+    """
+    data = self._preprocess(data, use_eos)
+    thought_vectors = []
+
+    batch_indices = np.arange(0, len(data), batch_size)
+    for batch, start_index in enumerate(batch_indices):
+      if verbose:
+        tf.logging.info("Batch %d / %d.", batch, len(batch_indices))
+
+      embeddings, mask = _batch_and_pad(
+          data[start_index:start_index + batch_size])
+      feed_dict = {
+          "encode_emb:0": embeddings,
+          "encode_mask:0": mask,
+      }
+      thought_vectors.extend(
+          sess.run("encoder/thought_vectors:0", feed_dict=feed_dict))
+
+    if use_norm:
+      thought_vectors = [v / np.linalg.norm(v) for v in thought_vectors]
+
+    return thought_vectors
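
To make the padding contract concrete, a toy illustration of the module-private _batch_and_pad helper above (values are hypothetical; the helper is called directly only for illustration):

    import numpy as np
    from skip_thoughts import skip_thoughts_encoder

    seqs = [[np.ones(3), 2 * np.ones(3)],  # two words, embedding dim 3
            [np.ones(3)]]                  # one word
    emb, mask = skip_thoughts_encoder._batch_and_pad(seqs)
    print(emb.shape)  # (2, 2, 3): padded to the longest sentence in the batch.
    print(mask)       # [[1 1]
                      #  [1 0]]  zeros mark padded positions.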

+ 369 - 0
skip_thoughts/skip_thoughts/skip_thoughts_model.py

@@ -0,0 +1,369 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Skip-Thoughts model for learning sentence vectors.
+
+The model is based on the paper:
+
+  "Skip-Thought Vectors"
+  Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
+  Antonio Torralba, Raquel Urtasun, Sanja Fidler.
+  https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf
+
+Layer normalization is applied based on the paper:
+
+  "Layer Normalization"
+  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+  https://arxiv.org/abs/1607.06450
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from skip_thoughts.ops import gru_cell
+from skip_thoughts.ops import input_ops
+
+
+def random_orthonormal_initializer(shape, dtype=tf.float32,
+                                   partition_info=None):  # pylint: disable=unused-argument
+  """Variable initializer that produces a random orthonormal matrix."""
+  if len(shape) != 2 or shape[0] != shape[1]:
+    raise ValueError("Expecting square shape, got %s" % shape)
+  _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
+  return u
+
+
+class SkipThoughtsModel(object):
+  """Skip-thoughts model."""
+
+  def __init__(self, config, mode="train", input_reader=None):
+    """Basic setup. The actual TensorFlow graph is constructed in build().
+
+    Args:
+      config: Object containing configuration parameters.
+      mode: "train", "eval" or "encode".
+      input_reader: Subclass of tf.ReaderBase for reading the input serialized
+        tf.Example protocol buffers. Defaults to TFRecordReader.
+
+    Raises:
+      ValueError: If mode is invalid.
+    """
+    if mode not in ["train", "eval", "encode"]:
+      raise ValueError("Unrecognized mode: %s" % mode)
+
+    self.config = config
+    self.mode = mode
+    self.reader = input_reader if input_reader else tf.TFRecordReader()
+
+    # Initializer used for non-recurrent weights.
+    self.uniform_initializer = tf.random_uniform_initializer(
+        minval=-self.config.uniform_init_scale,
+        maxval=self.config.uniform_init_scale)
+
+    # Input sentences represented as sequences of word ids. "encode" is the
+    # source sentence, "decode_pre" is the previous sentence and "decode_post"
+    # is the next sentence.
+    # Each is an int64 Tensor with shape [batch_size, padded_length].
+    self.encode_ids = None
+    self.decode_pre_ids = None
+    self.decode_post_ids = None
+
+    # Boolean masks distinguishing real words (1) from padded words (0).
+    # Each is an int32 Tensor with shape [batch_size, padded_length].
+    self.encode_mask = None
+    self.decode_pre_mask = None
+    self.decode_post_mask = None
+
+    # Input sentences represented as sequences of word embeddings.
+    # Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
+    self.encode_emb = None
+    self.decode_pre_emb = None
+    self.decode_post_emb = None
+
+    # The output from the sentence encoder.
+    # A float32 Tensor with shape [batch_size, num_gru_units].
+    self.thought_vectors = None
+
+    # The cross entropy losses and corresponding weights of the decoders. Used
+    # for evaluation.
+    self.target_cross_entropy_losses = []
+    self.target_cross_entropy_loss_weights = []
+
+    # The total loss to optimize.
+    self.total_loss = None
+
+  def build_inputs(self):
+    """Builds the ops for reading input data.
+
+    Outputs:
+      self.encode_ids
+      self.decode_pre_ids
+      self.decode_post_ids
+      self.encode_mask
+      self.decode_pre_mask
+      self.decode_post_mask
+    """
+    if self.mode == "encode":
+      # Word embeddings are fed from an external vocabulary which has possibly
+      # been expanded (see vocabulary_expansion.py).
+      encode_ids = None
+      decode_pre_ids = None
+      decode_post_ids = None
+      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
+      decode_pre_mask = None
+      decode_post_mask = None
+    else:
+      # Prefetch serialized tf.Example protos.
+      input_queue = input_ops.prefetch_input_data(
+          self.reader,
+          self.config.input_file_pattern,
+          shuffle=self.config.shuffle_input_data,
+          capacity=self.config.input_queue_capacity,
+          num_reader_threads=self.config.num_input_reader_threads)
+
+      # Deserialize a batch.
+      serialized = input_queue.dequeue_many(self.config.batch_size)
+      encode, decode_pre, decode_post = input_ops.parse_example_batch(
+          serialized)
+
+      encode_ids = encode.ids
+      decode_pre_ids = decode_pre.ids
+      decode_post_ids = decode_post.ids
+
+      encode_mask = encode.mask
+      decode_pre_mask = decode_pre.mask
+      decode_post_mask = decode_post.mask
+
+    self.encode_ids = encode_ids
+    self.decode_pre_ids = decode_pre_ids
+    self.decode_post_ids = decode_post_ids
+
+    self.encode_mask = encode_mask
+    self.decode_pre_mask = decode_pre_mask
+    self.decode_post_mask = decode_post_mask
+
+  def build_word_embeddings(self):
+    """Builds the word embeddings.
+
+    Inputs:
+      self.encode_ids
+      self.decode_pre_ids
+      self.decode_post_ids
+
+    Outputs:
+      self.encode_emb
+      self.decode_pre_emb
+      self.decode_post_emb
+    """
+    if self.mode == "encode":
+      # Word embeddings are fed from an external vocabulary which has possibly
+      # been expanded (see vocabulary_expansion.py).
+      encode_emb = tf.placeholder(tf.float32, (
+          None, None, self.config.word_embedding_dim), "encode_emb")
+      # No sequences to decode.
+      decode_pre_emb = None
+      decode_post_emb = None
+    else:
+      word_emb = tf.get_variable(
+          name="word_embedding",
+          shape=[self.config.vocab_size, self.config.word_embedding_dim],
+          initializer=self.uniform_initializer)
+
+      encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
+      decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
+      decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)
+
+    self.encode_emb = encode_emb
+    self.decode_pre_emb = decode_pre_emb
+    self.decode_post_emb = decode_post_emb
+
+  def _initialize_gru_cell(self, num_units):
+    """Initializes a GRU cell.
+
+    The Variables of the GRU cell are initialized in a way that exactly matches
+    the skip-thoughts paper: recurrent weights are initialized from random
+    orthonormal matrices and non-recurrent weights are initialized from random
+    uniform matrices.
+
+    Args:
+      num_units: Number of output units.
+
+    Returns:
+      cell: An instance of RNNCell with variable initializers that match the
+        skip-thoughts paper.
+    """
+    return gru_cell.LayerNormGRUCell(
+        num_units,
+        w_initializer=self.uniform_initializer,
+        u_initializer=random_orthonormal_initializer,
+        b_initializer=tf.constant_initializer(0.0))
+
+  def build_encoder(self):
+    """Builds the sentence encoder.
+
+    Inputs:
+      self.encode_emb
+      self.encode_mask
+
+    Outputs:
+      self.thought_vectors
+
+    Raises:
+      ValueError: if config.bidirectional_encoder is True and config.encoder_dim
+        is odd.
+    """
+    with tf.variable_scope("encoder") as scope:
+      length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")
+
+      if self.config.bidirectional_encoder:
+        if self.config.encoder_dim % 2:
+          raise ValueError(
+              "encoder_dim must be even when using a bidirectional encoder.")
+        num_units = self.config.encoder_dim // 2
+        cell_fw = self._initialize_gru_cell(num_units)  # Forward encoder
+        cell_bw = self._initialize_gru_cell(num_units)  # Backward encoder
+        _, states = tf.nn.bidirectional_dynamic_rnn(
+            cell_fw=cell_fw,
+            cell_bw=cell_bw,
+            inputs=self.encode_emb,
+            sequence_length=length,
+            dtype=tf.float32,
+            scope=scope)
+        thought_vectors = tf.concat(states, 1, name="thought_vectors")
+      else:
+        cell = self._initialize_gru_cell(self.config.encoder_dim)
+        _, state = tf.nn.dynamic_rnn(
+            cell=cell,
+            inputs=self.encode_emb,
+            sequence_length=length,
+            dtype=tf.float32,
+            scope=scope)
+        # Use an identity operation to name the Tensor in the Graph.
+        thought_vectors = tf.identity(state, name="thought_vectors")
+
+    self.thought_vectors = thought_vectors
+
+  def _build_decoder(self, name, embeddings, targets, mask, initial_state,
+                     reuse_logits):
+    """Builds a sentence decoder.
+
+    Args:
+      name: Decoder name.
+      embeddings: Batch of sentences to decode; a float32 Tensor with shape
+        [batch_size, padded_length, emb_dim].
+      targets: Batch of target word ids; an int64 Tensor with shape
+        [batch_size, padded_length].
+      mask: A 0/1 Tensor with shape [batch_size, padded_length].
+      initial_state: Initial state of the GRU. A float32 Tensor with shape
+        [batch_size, num_gru_cells].
+      reuse_logits: Whether to reuse the logits weights.
+    """
+    # Decoder RNN.
+    cell = self._initialize_gru_cell(self.config.encoder_dim)
+    with tf.variable_scope(name) as scope:
+      # Add a padding word at the start of each sentence (to correspond to the
+      # prediction of the first word) and remove the last word.
+      decoder_input = tf.pad(
+          embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
+      length = tf.reduce_sum(mask, 1, name="length")
+      decoder_output, _ = tf.nn.dynamic_rnn(
+          cell=cell,
+          inputs=decoder_input,
+          sequence_length=length,
+          initial_state=initial_state,
+          scope=scope)
+
+    # Stack batch vertically.
+    decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
+    targets = tf.reshape(targets, [-1])
+    weights = tf.to_float(tf.reshape(mask, [-1]))
+
+    # Logits.
+    with tf.variable_scope("logits", reuse=reuse_logits) as scope:
+      logits = tf.contrib.layers.fully_connected(
+          inputs=decoder_output,
+          num_outputs=self.config.vocab_size,
+          activation_fn=None,
+          weights_initializer=self.uniform_initializer,
+          scope=scope)
+
+    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=targets, logits=logits)
+    batch_loss = tf.reduce_sum(losses * weights)
+    tf.losses.add_loss(batch_loss)
+
+    tf.summary.scalar("losses/" + name, batch_loss)
+
+    self.target_cross_entropy_losses.append(losses)
+    self.target_cross_entropy_loss_weights.append(weights)
+
+  def build_decoders(self):
+    """Builds the sentence decoders.
+
+    Inputs:
+      self.decode_pre_emb
+      self.decode_post_emb
+      self.decode_pre_ids
+      self.decode_post_ids
+      self.decode_pre_mask
+      self.decode_post_mask
+      self.thought_vectors
+
+    Outputs:
+      self.target_cross_entropy_losses
+      self.target_cross_entropy_loss_weights
+    """
+    if self.mode != "encode":
+      # Pre-sentence decoder.
+      self._build_decoder("decoder_pre", self.decode_pre_emb,
+                          self.decode_pre_ids, self.decode_pre_mask,
+                          self.thought_vectors, False)
+
+      # Post-sentence decoder. Logits weights are reused.
+      self._build_decoder("decoder_post", self.decode_post_emb,
+                          self.decode_post_ids, self.decode_post_mask,
+                          self.thought_vectors, True)
+
+  def build_loss(self):
+    """Builds the loss Tensor.
+
+    Outputs:
+      self.total_loss
+    """
+    if self.mode != "encode":
+      total_loss = tf.losses.get_total_loss()
+      tf.summary.scalar("losses/total", total_loss)
+
+      self.total_loss = total_loss
+
+  def build_global_step(self):
+    """Builds the global step Tensor.
+
+    Outputs:
+      self.global_step
+    """
+    self.global_step = tf.contrib.framework.create_global_step()
+
+  def build(self):
+    """Creates all ops for training, evaluation or encoding."""
+    self.build_inputs()
+    self.build_word_embeddings()
+    self.build_encoder()
+    self.build_decoders()
+    self.build_loss()
+    self.build_global_step()

+ 191 - 0
skip_thoughts/skip_thoughts/skip_thoughts_model_test.py

@@ -0,0 +1,191 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow_models.skip_thoughts.skip_thoughts_model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+
+class SkipThoughtsModel(skip_thoughts_model.SkipThoughtsModel):
+  """Subclass of SkipThoughtsModel without the disk I/O."""
+
+  def build_inputs(self):
+    if self.mode == "encode":
+      # Encode mode doesn't read from disk, so defer to parent.
+      return super(SkipThoughtsModel, self).build_inputs()
+    else:
+      # Replace disk I/O with random Tensors.
+      self.encode_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.decode_pre_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.decode_post_ids = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.encode_mask = tf.ones_like(self.encode_ids)
+      self.decode_pre_mask = tf.ones_like(self.decode_pre_ids)
+      self.decode_post_mask = tf.ones_like(self.decode_post_ids)
+
+
+class SkipThoughtsModelTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(SkipThoughtsModelTest, self).setUp()
+    self._model_config = configuration.model_config()
+
+  def _countModelParameters(self):
+    """Counts the number of parameters in the model at top level scope."""
+    counter = {}
+    for v in tf.global_variables():
+      name = v.op.name.split("/")[0]
+      num_params = v.get_shape().num_elements()
+      if not num_params:
+        self.fail("Could not infer num_elements from Variable %s" % v.op.name)
+      counter[name] = counter.get(name, 0) + num_params
+    return counter
+
+  def _checkModelParameters(self):
+    """Verifies the number of parameters in the model."""
+    param_counts = self._countModelParameters()
+    expected_param_counts = {
+        # vocab_size * embedding_size
+        "word_embedding": 12400000,
+        # GRU Cells
+        "encoder": 21772800,
+        "decoder_pre": 21772800,
+        "decoder_post": 21772800,
+        # (encoder_dim + 1) * vocab_size
+        "logits": 48020000,
+        "global_step": 1,
+    }
+    self.assertDictEqual(expected_param_counts, param_counts)
+
+  def _checkOutputs(self, expected_shapes, feed_dict=None):
+    """Verifies that the model produces expected outputs.
+
+    Args:
+      expected_shapes: A dict mapping Tensor or Tensor name to expected output
+        shape.
+      feed_dict: Values of Tensors to feed into Session.run().
+    """
+    fetches = expected_shapes.keys()
+
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      outputs = sess.run(fetches, feed_dict)
+
+    for index, output in enumerate(outputs):
+      tensor = fetches[index]
+      expected = expected_shapes[tensor]
+      actual = output.shape
+      if expected != actual:
+        self.fail("Tensor %s has shape %s (expected %s)." % (tensor, actual,
+                                                             expected))
+
+  def testBuildForTraining(self):
+    model = SkipThoughtsModel(self._model_config, mode="train")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, length]
+        model.encode_ids: (128, 15),
+        model.decode_pre_ids: (128, 15),
+        model.decode_post_ids: (128, 15),
+        model.encode_mask: (128, 15),
+        model.decode_pre_mask: (128, 15),
+        model.decode_post_mask: (128, 15),
+        # [batch_size, length, word_embedding_dim]
+        model.encode_emb: (128, 15, 620),
+        model.decode_pre_emb: (128, 15, 620),
+        model.decode_post_emb: (128, 15, 620),
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (128, 2400),
+        # [batch_size * length]
+        model.target_cross_entropy_losses[0]: (1920,),
+        model.target_cross_entropy_losses[1]: (1920,),
+        # [batch_size * length]
+        model.target_cross_entropy_loss_weights[0]: (1920,),
+        model.target_cross_entropy_loss_weights[1]: (1920,),
+        # Scalar
+        model.total_loss: (),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForEval(self):
+    model = SkipThoughtsModel(self._model_config, mode="eval")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, length]
+        model.encode_ids: (128, 15),
+        model.decode_pre_ids: (128, 15),
+        model.decode_post_ids: (128, 15),
+        model.encode_mask: (128, 15),
+        model.decode_pre_mask: (128, 15),
+        model.decode_post_mask: (128, 15),
+        # [batch_size, length, word_embedding_dim]
+        model.encode_emb: (128, 15, 620),
+        model.decode_pre_emb: (128, 15, 620),
+        model.decode_post_emb: (128, 15, 620),
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (128, 2400),
+        # [batch_size * length]
+        model.target_cross_entropy_losses[0]: (1920,),
+        model.target_cross_entropy_losses[1]: (1920,),
+        # [batch_size * length]
+        model.target_cross_entropy_loss_weights[0]: (1920,),
+        model.target_cross_entropy_loss_weights[1]: (1920,),
+        # Scalar
+        model.total_loss: (),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForEncode(self):
+    model = SkipThoughtsModel(self._model_config, mode="encode")
+    model.build()
+
+    # Test feeding a batch of word embeddings to get skip thought vectors.
+    encode_emb = np.random.rand(64, 15, 620)
+    encode_mask = np.ones((64, 15), dtype=np.int64)
+    feed_dict = {model.encode_emb: encode_emb, model.encode_mask: encode_mask}
+    expected_shapes = {
+        # [batch_size, encoder_dim]
+        model.thought_vectors: (64, 2400),
+    }
+    self._checkOutputs(expected_shapes, feed_dict)
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 199 - 0
skip_thoughts/skip_thoughts/track_perplexity.py

@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tracks training progress via per-word perplexity.
+
+This script should be run concurrently with training so that summaries show up
+in TensorBoard.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os.path
+import time
+
+
+import numpy as np
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", None,
+                       "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("checkpoint_dir", None,
+                       "Directory containing model checkpoints.")
+tf.flags.DEFINE_string("eval_dir", None, "Directory to write event logs to.")
+
+tf.flags.DEFINE_integer("eval_interval_secs", 600,
+                        "Interval between evaluation runs.")
+tf.flags.DEFINE_integer("num_eval_examples", 50000,
+                        "Number of examples for evaluation.")
+
+tf.flags.DEFINE_integer("min_global_step", 100,
+                        "Minimum global step to run evaluation.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def evaluate_model(sess, losses, weights, num_batches, global_step,
+                   summary_writer, summary_op):
+  """Computes perplexity-per-word over the evaluation dataset.
+
+  Summaries and perplexity-per-word are written out to the eval directory.
+
+  Args:
+    sess: Session object.
+    losses: A Tensor of any shape; the target cross entropy losses for the
+      current batch.
+    weights: A Tensor of weights corresponding to losses.
+    num_batches: Integer; the number of evaluation batches.
+    global_step: Integer; global step of the model checkpoint.
+    summary_writer: Instance of SummaryWriter.
+    summary_op: Op for generating model summaries.
+  """
+  # Log model summaries on a single batch.
+  summary_str = sess.run(summary_op)
+  summary_writer.add_summary(summary_str, global_step)
+
+  start_time = time.time()
+  sum_losses = 0.0
+  sum_weights = 0.0
+  for i in range(num_batches):
+    batch_losses, batch_weights = sess.run([losses, weights])
+    sum_losses += np.sum(batch_losses * batch_weights)
+    sum_weights += np.sum(batch_weights)
+    if not i % 100:
+      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
+                      num_batches)
+  eval_time = time.time() - start_time
+
+  perplexity = math.exp(sum_losses / sum_weights)
+  tf.logging.info("Perplexity = %f (%.2f sec)", perplexity, eval_time)
+
+  # Log perplexity to the SummaryWriter.
+  summary = tf.Summary()
+  value = summary.value.add()
+  value.simple_value = perplexity
+  value.tag = "perplexity"
+  summary_writer.add_summary(summary, global_step)
+
+  # Write the Events file to the eval directory.
+  summary_writer.flush()
+  tf.logging.info("Finished processing evaluation at global step %d.",
+                  global_step)
+
+
+def run_once(model, losses, weights, saver, summary_writer, summary_op):
+  """Evaluates the latest model checkpoint.
+
+  Args:
+    model: Instance of SkipThoughtsModel; the model to evaluate.
+    losses: Tensor; the target cross entropy losses for the current batch.
+    weights: A Tensor of weights corresponding to losses.
+    saver: Instance of tf.train.Saver for restoring model Variables.
+    summary_writer: Instance of FileWriter.
+    summary_op: Op for generating model summaries.
+  """
+  model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+  if not model_path:
+    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
+                    FLAGS.checkpoint_dir)
+    return
+
+  with tf.Session() as sess:
+    # Load model from checkpoint.
+    tf.logging.info("Loading model from checkpoint: %s", model_path)
+    saver.restore(sess, model_path)
+    global_step = tf.train.global_step(sess, model.global_step.name)
+    tf.logging.info("Successfully loaded %s at global step = %d.",
+                    os.path.basename(model_path), global_step)
+    if global_step < FLAGS.min_global_step:
+      tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
+                      FLAGS.min_global_step)
+      return
+
+    # Start the queue runners.
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(coord=coord)
+
+    num_eval_batches = int(
+        math.ceil(FLAGS.num_eval_examples / model.config.batch_size))
+
+    # Run evaluation on the latest checkpoint.
+    try:
+      evaluate_model(sess, losses, weights, num_eval_batches, global_step,
+                     summary_writer, summary_op)
+    except tf.errors.InvalidArgumentError:
+      tf.logging.error(
+          "Evaluation raised InvalidArgumentError (e.g. due to Nans).")
+    finally:
+      coord.request_stop()
+      coord.join(threads, stop_grace_period_secs=10)
+
+
+def main(unused_argv):
+  if not FLAGS.input_file_pattern:
+    raise ValueError("--input_file_pattern is required.")
+  if not FLAGS.checkpoint_dir:
+    raise ValueError("--checkpoint_dir is required.")
+  if not FLAGS.eval_dir:
+    raise ValueError("--eval_dir is required.")
+
+  # Create the evaluation directory if it doesn't exist.
+  eval_dir = FLAGS.eval_dir
+  if not tf.gfile.IsDirectory(eval_dir):
+    tf.logging.info("Creating eval directory: %s", eval_dir)
+    tf.gfile.MakeDirs(eval_dir)
+
+  g = tf.Graph()
+  with g.as_default():
+    # Build the model for evaluation.
+    model_config = configuration.model_config(
+        input_file_pattern=FLAGS.input_file_pattern,
+        input_queue_capacity=FLAGS.num_eval_examples,
+        shuffle_input_data=False)
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="eval")
+    model.build()
+
+    losses = tf.concat(model.target_cross_entropy_losses, 0)
+    weights = tf.concat(model.target_cross_entropy_loss_weights, 0)
+
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
+    # Create the summary operation and the summary writer.
+    summary_op = tf.summary.merge_all()
+    summary_writer = tf.summary.FileWriter(eval_dir)
+
+    g.finalize()
+
+    # Run a new evaluation run every eval_interval_secs.
+    while True:
+      start = time.time()
+      tf.logging.info("Starting evaluation at " + time.strftime(
+          "%Y-%m-%d-%H:%M:%S", time.localtime()))
+      run_once(model, losses, weights, saver, summary_writer, summary_op)
+      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
+      if time_to_next_eval > 0:
+        time.sleep(time_to_next_eval)
+
+
+if __name__ == "__main__":
+  tf.app.run()
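
Note (editor's sketch, not part of this commit): the perplexity reported by
evaluate_model above is just the exponential of the weight-averaged per-token
cross entropy. A minimal illustration of that arithmetic, with made-up loss
and weight values standing in for what sess.run([losses, weights]) returns:

    import math
    import numpy as np

    # Two made-up "batches" of per-token cross-entropy losses and 0/1 padding
    # weights.
    batch_losses = [np.array([2.3, 1.9, 2.1]), np.array([2.0, 2.4, 0.0])]
    batch_weights = [np.array([1.0, 1.0, 1.0]), np.array([1.0, 1.0, 0.0])]

    sum_losses = sum(np.sum(l * w) for l, w in zip(batch_losses, batch_weights))
    sum_weights = sum(np.sum(w) for w in batch_weights)

    # Perplexity is exp of the weighted mean cross entropy, as in evaluate_model.
    perplexity = math.exp(sum_losses / sum_weights)
    print("perplexity = %.2f" % perplexity)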

+ 99 - 0
skip_thoughts/skip_thoughts/train.py

@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Train the skip-thoughts model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from skip_thoughts import configuration
+from skip_thoughts import skip_thoughts_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", None,
+                       "File pattern of sharded TFRecord files containing "
+                       "tf.Example protos.")
+tf.flags.DEFINE_string("train_dir", None,
+                       "Directory for saving and loading checkpoints.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _setup_learning_rate(config, global_step):
+  """Sets up the learning rate with optional exponential decay.
+
+  Args:
+    config: Object containing learning rate configuration parameters.
+    global_step: Tensor; the global step.
+
+  Returns:
+    learning_rate: Tensor; the learning rate, with exponential decay if set.
+  """
+  if config.learning_rate_decay_factor > 0:
+    learning_rate = tf.train.exponential_decay(
+        learning_rate=float(config.learning_rate),
+        global_step=global_step,
+        decay_steps=config.learning_rate_decay_steps,
+        decay_rate=config.learning_rate_decay_factor,
+        staircase=False)
+  else:
+    learning_rate = tf.constant(config.learning_rate)
+  return learning_rate
+
+
+def main(unused_argv):
+  if not FLAGS.input_file_pattern:
+    raise ValueError("--input_file_pattern is required.")
+  if not FLAGS.train_dir:
+    raise ValueError("--train_dir is required.")
+
+  model_config = configuration.model_config(
+      input_file_pattern=FLAGS.input_file_pattern)
+  training_config = configuration.training_config()
+
+  tf.logging.info("Building training graph.")
+  g = tf.Graph()
+  with g.as_default():
+    model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="train")
+    model.build()
+
+    learning_rate = _setup_learning_rate(training_config, model.global_step)
+    optimizer = tf.train.AdamOptimizer(learning_rate)
+
+    train_tensor = tf.contrib.slim.learning.create_train_op(
+        total_loss=model.total_loss,
+        optimizer=optimizer,
+        global_step=model.global_step,
+        clip_gradient_norm=training_config.clip_gradient_norm)
+
+    saver = tf.train.Saver()
+
+  tf.contrib.slim.learning.train(
+      train_op=train_tensor,
+      logdir=FLAGS.train_dir,
+      graph=g,
+      global_step=model.global_step,
+      number_of_steps=training_config.number_of_steps,
+      save_summaries_secs=training_config.save_summaries_secs,
+      saver=saver,
+      save_interval_secs=training_config.save_model_secs)
+
+
+if __name__ == "__main__":
+  tf.app.run()
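
Note (editor's sketch, not part of this commit): when a decay factor is
configured, _setup_learning_rate above delegates to tf.train.exponential_decay
with staircase=False, which follows initial_lr * decay_rate ** (global_step /
decay_steps). A small stand-alone illustration; the rate, factor, and step
values below are hypothetical, not the ones in configuration.py:

    # Hypothetical values for illustration only.
    initial_lr = 0.008
    decay_rate = 0.5
    decay_steps = 400000

    def decayed_lr(global_step):
      # Continuous (staircase=False) exponential decay.
      return initial_lr * decay_rate ** (global_step / float(decay_steps))

    for step in (0, 200000, 400000, 800000):
      print(step, decayed_lr(step))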

+ 203 - 0
skip_thoughts/skip_thoughts/vocabulary_expansion.py

@@ -0,0 +1,203 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compute an expanded vocabulary of embeddings using a word2vec model.
+
+This script loads the word embeddings from a trained skip-thoughts model and
+from a trained word2vec model (typically with a larger vocabulary). It trains a
+linear regression model without regularization to learn a linear mapping from
+the word2vec embedding space to the skip-thoughts embedding space. The model is
+then applied to all words in the word2vec vocabulary, yielding vectors in the
+skip-thoughts word embedding space for the union of the two vocabularies.
+
+The linear regression task is to learn a parameter matrix W to minimize
+  || X - Y * W ||^2,
+where X is a matrix of skip-thoughts embeddings of shape [num_words, dim1],
+Y is a matrix of word2vec embeddings of shape [num_words, dim2], and W is a
+matrix of shape [dim2, dim1].
+
+This is based on the "Translation Matrix" method from the paper:
+
+  "Exploiting Similarities among Languages for Machine Translation"
+  Tomas Mikolov, Quoc V. Le, Ilya Sutskever
+  https://arxiv.org/abs/1309.4168
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os.path
+
+
+import gensim.models
+import numpy as np
+import sklearn.linear_model
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("skip_thoughts_model", None,
+                       "Checkpoint file or directory containing a checkpoint "
+                       "file.")
+
+tf.flags.DEFINE_string("skip_thoughts_vocab", None,
+                       "Path to vocabulary file containing a list of newline-"
+                       "separated words where the word id is the "
+                       "corresponding 0-based index in the file.")
+
+tf.flags.DEFINE_string("word2vec_model", None,
+                       "File containing a word2vec model in binary format.")
+
+tf.flags.DEFINE_string("output_dir", None, "Output directory.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def _load_skip_thoughts_embeddings(checkpoint_path):
+  """Loads the embedding matrix from a skip-thoughts model checkpoint.
+
+  Args:
+    checkpoint_path: Model checkpoint file or directory containing a checkpoint
+        file.
+
+  Returns:
+    word_embedding: A numpy array of shape [vocab_size, embedding_dim].
+
+  Raises:
+    ValueError: If no checkpoint file matches checkpoint_path.
+  """
+  if tf.gfile.IsDirectory(checkpoint_path):
+    checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
+    if not checkpoint_file:
+      raise ValueError("No checkpoint file found in %s" % checkpoint_path)
+  else:
+    checkpoint_file = checkpoint_path
+
+  tf.logging.info("Loading skip-thoughts embedding matrix from %s",
+                  checkpoint_file)
+  reader = tf.train.NewCheckpointReader(checkpoint_file)
+  word_embedding = reader.get_tensor("word_embedding")
+  tf.logging.info("Loaded skip-thoughts embedding matrix of shape %s",
+                  word_embedding.shape)
+
+  return word_embedding
+
+
+def _load_vocabulary(filename):
+  """Loads a vocabulary file.
+
+  Args:
+    filename: Path to text file containing newline-separated words.
+
+  Returns:
+    vocab: A dictionary mapping word to word id.
+  """
+  tf.logging.info("Reading vocabulary from %s", filename)
+  vocab = collections.OrderedDict()
+  with tf.gfile.GFile(filename, mode="r") as f:
+    for i, line in enumerate(f):
+      word = line.decode("utf-8").strip()
+      assert word not in vocab, "Attempting to add word twice: %s" % word
+      vocab[word] = i
+  tf.logging.info("Read vocabulary of size %d", len(vocab))
+  return vocab
+
+
+def _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab, word2vec):
+  """Runs vocabulary expansion on a skip-thoughts model using a word2vec model.
+
+  Args:
+    skip_thoughts_emb: A numpy array of shape [skip_thoughts_vocab_size,
+        skip_thoughts_embedding_dim].
+    skip_thoughts_vocab: A dictionary of word to id.
+    word2vec: An instance of gensim.models.Word2Vec.
+
+  Returns:
+    combined_emb: A dictionary mapping words to embedding vectors.
+  """
+  # Find words shared between the two vocabularies.
+  tf.logging.info("Finding shared words")
+  shared_words = [w for w in word2vec.vocab if w in skip_thoughts_vocab]
+
+  # Select embedding vectors for shared words.
+  tf.logging.info("Selecting embeddings for %d shared words", len(shared_words))
+  shared_st_emb = skip_thoughts_emb[[
+      skip_thoughts_vocab[w] for w in shared_words
+  ]]
+  shared_w2v_emb = word2vec[shared_words]
+
+  # Train a linear regression model on the shared embedding vectors.
+  tf.logging.info("Training linear regression model")
+  model = sklearn.linear_model.LinearRegression()
+  model.fit(shared_w2v_emb, shared_st_emb)
+
+  # Create the expanded vocabulary.
+  tf.logging.info("Creating embeddings for expanded vocabulary")
+  combined_emb = collections.OrderedDict()
+  for w in word2vec.vocab:
+    # Ignore words with underscores (spaces).
+    if "_" not in w:
+      w_emb = model.predict(word2vec[w].reshape(1, -1))
+      combined_emb[w] = w_emb.reshape(-1)
+
+  for w in skip_thoughts_vocab:
+    combined_emb[w] = skip_thoughts_emb[skip_thoughts_vocab[w]]
+
+  tf.logging.info("Created expanded vocabulary of %d words", len(combined_emb))
+
+  return combined_emb
+
+
+def main(unused_argv):
+  if not FLAGS.skip_thoughts_model:
+    raise ValueError("--skip_thoughts_model is required.")
+  if not FLAGS.skip_thoughts_vocab:
+    raise ValueError("--skip_thoughts_vocab is required.")
+  if not FLAGS.word2vec_model:
+    raise ValueError("--word2vec_model is required.")
+  if not FLAGS.output_dir:
+    raise ValueError("--output_dir is required.")
+
+  if not tf.gfile.IsDirectory(FLAGS.output_dir):
+    tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  # Load the skip-thoughts embeddings and vocabulary.
+  skip_thoughts_emb = _load_skip_thoughts_embeddings(FLAGS.skip_thoughts_model)
+  skip_thoughts_vocab = _load_vocabulary(FLAGS.skip_thoughts_vocab)
+
+  # Load the Word2Vec model.
+  word2vec = gensim.models.Word2Vec.load_word2vec_format(
+      FLAGS.word2vec_model, binary=True)
+
+  # Run vocabulary expansion.
+  embedding_map = _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab,
+                                     word2vec)
+
+  # Save the output.
+  vocab = embedding_map.keys()
+  vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
+  with tf.gfile.GFile(vocab_file, "w") as f:
+    f.write("\n".join(vocab))
+  tf.logging.info("Wrote vocabulary file to %s", vocab_file)
+
+  embeddings = np.array(embedding_map.values())
+  embeddings_file = os.path.join(FLAGS.output_dir, "embeddings.npy")
+  np.save(embeddings_file, embeddings)
+  tf.logging.info("Wrote embeddings file to %s", embeddings_file)
+
+
+if __name__ == "__main__":
+  tf.app.run()
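
Note (editor's sketch, not part of this commit): the "translation matrix" fit
performed by _expand_vocabulary can be illustrated on made-up arrays; shapes
and values below are arbitrary, and sklearn's LinearRegression also fits an
intercept by default:

    import numpy as np
    import sklearn.linear_model

    # Fake embeddings for the words shared by the two vocabularies.
    num_shared, dim_w2v, dim_st = 50, 8, 6
    shared_w2v_emb = np.random.randn(num_shared, dim_w2v)  # word2vec side (Y)
    shared_st_emb = np.random.randn(num_shared, dim_st)    # skip-thoughts side (X)

    # Learn the mapping W that minimizes ||X - Y * W||^2.
    model = sklearn.linear_model.LinearRegression()
    model.fit(shared_w2v_emb, shared_st_emb)

    # Project a word2vec-only word into the skip-thoughts embedding space.
    new_w2v_vector = np.random.randn(dim_w2v)
    expanded_st_vector = model.predict(new_w2v_vector.reshape(1, -1)).reshape(-1)
    print(expanded_st_vector.shape)  # (6,)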

+ 1 - 1
syntaxnet/g3doc/universal.md

@@ -18,7 +18,7 @@ The following table shows their accuracy on Universal
 Dependencies test sets for different types of annotations.
 
 Language | No. tokens | POS | fPOS | Morph | UAS | LAS
---------  | :--: | :--: | :--: | :--: | :--: | :--: | :--:
+--------  | :--: | :--: | :--: | :--: | :--: | :--:
 Ancient_Greek-PROIEL | 18502 | 97.14% | 96.97% | 89.77% | 78.74% | 73.15%
 Ancient_Greek | 25251 | 93.22% | 84.22% | 90.01% | 68.98% | 62.07%
 Arabic | 28268 | 95.65% | 91.03% | 91.23% | 81.49% | 75.82%

+ 3 - 3
textsum/seq2seq_attention_model.py

@@ -166,7 +166,7 @@ class Seq2SeqAttentionModel(object):
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
              state_is_tuple=False)
-          (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
+          (emb_encoder_inputs, fw_state, _) = tf.contrib.rnn.static_bidirectional_rnn(
              cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
              sequence_length=article_lens)
       encoder_outputs = emb_encoder_inputs
@@ -200,7 +200,7 @@ class Seq2SeqAttentionModel(object):
         # During decoding, follow up _dec_in_state are fed from beam_search.
         # dec_out_state are stored by beam_search for next step feeding.
         initial_state_attention = (hps.mode == 'decode')
-        decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
+        decoder_outputs, self._dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
            emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
            cell, num_heads=1, loop_function=loop_function,
            initial_state_attention=initial_state_attention)
@@ -234,7 +234,7 @@ class Seq2SeqAttentionModel(object):
          self._loss = seq2seq_lib.sampled_sequence_loss(
              decoder_outputs, targets, loss_weights, sampled_loss_func)
        else:
-          self._loss = tf.nn.seq2seq.sequence_loss(
+          self._loss = tf.contrib.legacy_seq2seq.sequence_loss(
              model_outputs, targets, loss_weights)
        tf.summary.scalar('loss', tf.minimum(12.0, self._loss))
 
 

+ 1 - 1
transformer/cluttered_mnist.py

@@ -123,7 +123,7 @@ y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
 
 
 # %% Define loss/eval/training functions
 cross_entropy = tf.reduce_mean(
-    tf.nn.softmax_cross_entropy_with_logits(logits=y_logits, targets=y))
+    tf.nn.softmax_cross_entropy_with_logits(logits=y_logits, labels=y))
 opt = tf.train.AdamOptimizer()
 optimizer = opt.minimize(cross_entropy)
 grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

+ 1 - 1
tutorials/README.md

@@ -1,3 +1,3 @@
 # Tutorial Models
 
 
-This repository contains models referenced to from the [TensorFlow tutorials](https://www.tensorflow.org/tutorials/). We recommend installing TensorFlow from the [nightly builds](https://github.com/tensorflow/tensorflow#installation) rather than the r0.12 release before running these models.
+This folder contains models referenced to from the [TensorFlow tutorials](https://www.tensorflow.org/tutorials/).

+ 1 - 1
tutorials/rnn/translate/seq2seq_model.py

@@ -100,7 +100,7 @@ class Seq2SeqModel(object):
      b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
      output_projection = (w, b)
 
 
-      def sampled_loss(inputs, labels):
+      def sampled_loss(labels, inputs):
        labels = tf.reshape(labels, [-1, 1])
        # We need to compute the sampled_softmax_loss using 32bit floats to
        # avoid numerical instabilities.