Browse Source

Moving example models from github.com/tensorflow/tensorflow to github.com/tensorflow/models

Neal Wu 9 năm trước cách đây
mục cha
commit
86ecc9730d
45 tập tin đã thay đổi với 5867 bổ sung0 xóa
  1. 0 0
      tutorials/__init__.py
  2. 125 0
      tutorials/embedding/BUILD
  3. 51 0
      tutorials/embedding/README.md
  4. 21 0
      tutorials/embedding/__init__.py
  5. 534 0
      tutorials/embedding/word2vec.py
  6. 355 0
      tutorials/embedding/word2vec_kernels.cc
  7. 73 0
      tutorials/embedding/word2vec_ops.cc
  8. 439 0
      tutorials/embedding/word2vec_optimized.py
  9. 62 0
      tutorials/embedding/word2vec_optimized_test.py
  10. 62 0
      tutorials/embedding/word2vec_test.py
  11. 0 0
      tutorials/image/__init__.py
  12. 29 0
      tutorials/image/alexnet/BUILD
  13. 0 0
      tutorials/image/alexnet/__init__.py
  14. 246 0
      tutorials/image/alexnet/alexnet_benchmark.py
  15. 87 0
      tutorials/image/cifar10/BUILD
  16. 10 0
      tutorials/image/cifar10/README.md
  17. 22 0
      tutorials/image/cifar10/__init__.py
  18. 399 0
      tutorials/image/cifar10/cifar10.py
  19. 157 0
      tutorials/image/cifar10/cifar10_eval.py
  20. 253 0
      tutorials/image/cifar10/cifar10_input.py
  21. 66 0
      tutorials/image/cifar10/cifar10_input_test.py
  22. 273 0
      tutorials/image/cifar10/cifar10_multi_gpu_train.py
  23. 120 0
      tutorials/image/cifar10/cifar10_train.py
  24. 30 0
      tutorials/image/imagenet/BUILD
  25. 227 0
      tutorials/image/imagenet/classify_image.py
  26. 42 0
      tutorials/image/mnist/BUILD
  27. 0 0
      tutorials/image/mnist/__init__.py
  28. 339 0
      tutorials/image/mnist/convolutional.py
  29. 80 0
      tutorials/rnn/BUILD
  30. 13 0
      tutorials/rnn/README.md
  31. 19 0
      tutorials/rnn/__init__.py
  32. 20 0
      tutorials/rnn/linear.py
  33. 61 0
      tutorials/rnn/ptb/BUILD
  34. 21 0
      tutorials/rnn/ptb/__init__.py
  35. 371 0
      tutorials/rnn/ptb/ptb_word_lm.py
  36. 122 0
      tutorials/rnn/ptb/reader.py
  37. 68 0
      tutorials/rnn/ptb/reader_test.py
  38. 21 0
      tutorials/rnn/rnn.py
  39. 21 0
      tutorials/rnn/rnn_cell.py
  40. 22 0
      tutorials/rnn/seq2seq.py
  41. 84 0
      tutorials/rnn/translate/BUILD
  42. 22 0
      tutorials/rnn/translate/__init__.py
  43. 290 0
      tutorials/rnn/translate/data_utils.py
  44. 313 0
      tutorials/rnn/translate/seq2seq_model.py
  45. 297 0
      tutorials/rnn/translate/translate.py

+ 0 - 0
tutorials/__init__.py


+ 125 - 0
tutorials/embedding/BUILD

@@ -0,0 +1,125 @@
+# Description:
+# TensorFlow model for word2vec
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+py_library(
+    name = "package",
+    srcs = [
+        "__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":gen_word2vec",
+        ":word2vec",
+        ":word2vec_optimized",
+    ],
+)
+
+py_binary(
+    name = "word2vec",
+    srcs = [
+        "word2vec.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gen_word2vec",
+        ":word2vec_kernels",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:platform",
+    ],
+)
+
+py_binary(
+    name = "word2vec_optimized",
+    srcs = [
+        "word2vec_optimized.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gen_word2vec",
+        ":word2vec_kernels",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:platform",
+    ],
+)
+
+py_test(
+    name = "word2vec_test",
+    size = "small",
+    srcs = ["word2vec_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "notsan",  # b/25864127
+    ],
+    deps = [
+        ":word2vec",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "word2vec_optimized_test",
+    size = "small",
+    srcs = ["word2vec_optimized_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "notsan",
+    ],
+    deps = [
+        ":word2vec_optimized",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+cc_library(
+    name = "word2vec_ops",
+    srcs = [
+        "word2vec_ops.cc",
+    ],
+    linkstatic = 1,
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/core:framework",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "word2vec_kernels",
+    srcs = [
+        "word2vec_kernels.cc",
+    ],
+    linkstatic = 1,
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":word2vec_ops",
+        "//tensorflow/core",
+    ],
+    alwayslink = 1,
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_word2vec",
+    out = "gen_word2vec.py",
+    deps = [":word2vec_ops"],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 51 - 0
tutorials/embedding/README.md

@@ -0,0 +1,51 @@
+This directory contains models for unsupervised training of word embeddings
+using the model described in:
+
+(Mikolov, et. al.) [Efficient Estimation of Word Representations in Vector Space](http://arxiv.org/abs/1301.3781),
+ICLR 2013.
+
+Detailed instructions on how to get started and use them are available in the
+tutorials. Brief instructions are below.
+
+* [Word2Vec Tutorial](http://tensorflow.org/tutorials/word2vec/index.md)
+
+To download the example text and evaluation data:
+
+```shell
+wget http://mattmahoney.net/dc/text8.zip -O text8.zip
+unzip text8.zip
+wget https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip
+unzip -p source-archive.zip  word2vec/trunk/questions-words.txt > questions-words.txt
+rm source-archive.zip
+```
+
+Assuming you are using the pip package install and have cloned the git
+repository, navigate into this directory and run using:
+
+```shell
+cd tensorflow/models/embedding
+python word2vec_optimized.py \
+  --train_data=text8 \
+  --eval_data=questions-words.txt \
+  --save_path=/tmp/
+```
+
+To run the code from sources using bazel:
+
+```shell
+bazel run -c opt tensorflow/models/embedding/word2vec_optimized -- \
+  --train_data=text8 \
+  --eval_data=questions-words.txt \
+  --save_path=/tmp/
+```
+
+Here is a short overview of what is in this directory.
+
+File | What's in it?
+--- | ---
+`word2vec.py` | A version of word2vec implemented using TensorFlow ops and minibatching.
+`word2vec_test.py` | Integration test for word2vec.
+`word2vec_optimized.py` | A version of word2vec implemented using C ops that does no minibatching.
+`word2vec_optimized_test.py` | Integration test for word2vec_optimized.
+`word2vec_kernels.cc` | Kernels for the custom input and training ops.
+`word2vec_ops.cc` | The declarations of the custom ops.

+ 21 - 0
tutorials/embedding/__init__.py

@@ -0,0 +1,21 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Import generated word2vec optimized ops into embedding package."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.models.embedding import gen_word2vec

+ 534 - 0
tutorials/embedding/word2vec.py

@@ -0,0 +1,534 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Multi-threaded word2vec mini-batched skip-gram model.
+
+Trains the model described in:
+(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space
+ICLR 2013.
+http://arxiv.org/abs/1301.3781
+This model does traditional minibatching.
+
+The key ops used are:
+* placeholder for feeding in tensors for each example.
+* embedding_lookup for fetching rows from the embedding matrix.
+* sigmoid_cross_entropy_with_logits to calculate the loss.
+* GradientDescentOptimizer for optimizing the loss.
+* skipgram custom op that does input processing.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import threading
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.embedding import gen_word2vec as word2vec
+
+flags = tf.app.flags
+
+flags.DEFINE_string("save_path", None, "Directory to write the model and "
+                    "training summaries.")
+flags.DEFINE_string("train_data", None, "Training text file. "
+                    "E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
+flags.DEFINE_string(
+    "eval_data", None, "File consisting of analogies of four tokens."
+    "embedding 2 - embedding 1 + embedding 3 should be close "
+    "to embedding 4."
+    "See README.md for how to get 'questions-words.txt'.")
+flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
+flags.DEFINE_integer(
+    "epochs_to_train", 15,
+    "Number of epochs to train. Each epoch processes the training data once "
+    "completely.")
+flags.DEFINE_float("learning_rate", 0.2, "Initial learning rate.")
+flags.DEFINE_integer("num_neg_samples", 100,
+                     "Negative samples per training example.")
+flags.DEFINE_integer("batch_size", 16,
+                     "Number of training examples processed per step "
+                     "(size of a minibatch).")
+flags.DEFINE_integer("concurrent_steps", 12,
+                     "The number of concurrent training steps.")
+flags.DEFINE_integer("window_size", 5,
+                     "The number of words to predict to the left and right "
+                     "of the target word.")
+flags.DEFINE_integer("min_count", 5,
+                     "The minimum number of word occurrences for it to be "
+                     "included in the vocabulary.")
+flags.DEFINE_float("subsample", 1e-3,
+                   "Subsample threshold for word occurrence. Words that appear "
+                   "with higher frequency will be randomly down-sampled. Set "
+                   "to 0 to disable.")
+flags.DEFINE_boolean(
+    "interactive", False,
+    "If true, enters an IPython interactive session to play with the trained "
+    "model. E.g., try model.analogy(b'france', b'paris', b'russia') and "
+    "model.nearby([b'proton', b'elephant', b'maxwell'])")
+flags.DEFINE_integer("statistics_interval", 5,
+                     "Print statistics every n seconds.")
+flags.DEFINE_integer("summary_interval", 5,
+                     "Save training summary to file every n seconds (rounded "
+                     "up to statistics interval).")
+flags.DEFINE_integer("checkpoint_interval", 600,
+                     "Checkpoint the model (i.e. save the parameters) every n "
+                     "seconds (rounded up to statistics interval).")
+
+FLAGS = flags.FLAGS
+
+
+class Options(object):
+  """Options used by our word2vec model."""
+
+  def __init__(self):
+    # Model options.
+
+    # Embedding dimension.
+    self.emb_dim = FLAGS.embedding_size
+
+    # Training options.
+    # The training text file.
+    self.train_data = FLAGS.train_data
+
+    # Number of negative samples per example.
+    self.num_samples = FLAGS.num_neg_samples
+
+    # The initial learning rate.
+    self.learning_rate = FLAGS.learning_rate
+
+    # Number of epochs to train. After these many epochs, the learning
+    # rate decays linearly to zero and the training stops.
+    self.epochs_to_train = FLAGS.epochs_to_train
+
+    # Concurrent training steps.
+    self.concurrent_steps = FLAGS.concurrent_steps
+
+    # Number of examples for one training step.
+    self.batch_size = FLAGS.batch_size
+
+    # The number of words to predict to the left and right of the target word.
+    self.window_size = FLAGS.window_size
+
+    # The minimum number of word occurrences for it to be included in the
+    # vocabulary.
+    self.min_count = FLAGS.min_count
+
+    # Subsampling threshold for word occurrence.
+    self.subsample = FLAGS.subsample
+
+    # How often to print statistics.
+    self.statistics_interval = FLAGS.statistics_interval
+
+    # How often to write to the summary file (rounds up to the nearest
+    # statistics_interval).
+    self.summary_interval = FLAGS.summary_interval
+
+    # How often to write checkpoints (rounds up to the nearest statistics
+    # interval).
+    self.checkpoint_interval = FLAGS.checkpoint_interval
+
+    # Where to write out summaries.
+    self.save_path = FLAGS.save_path
+    if not os.path.exists(self.save_path):
+      os.makedirs(self.save_path)
+
+    # Eval options.
+    # The text file for eval.
+    self.eval_data = FLAGS.eval_data
+
+
+class Word2Vec(object):
+  """Word2Vec model (Skipgram)."""
+
+  def __init__(self, options, session):
+    self._options = options
+    self._session = session
+    self._word2id = {}
+    self._id2word = []
+    self.build_graph()
+    self.build_eval_graph()
+    self.save_vocab()
+
+  def read_analogies(self):
+    """Reads through the analogy question file.
+
+    Returns:
+      questions: a [n, 4] numpy array containing the analogy question's
+                 word ids.
+      questions_skipped: questions skipped due to unknown words.
+    """
+    questions = []
+    questions_skipped = 0
+    with open(self._options.eval_data, "rb") as analogy_f:
+      for line in analogy_f:
+        if line.startswith(b":"):  # Skip comments.
+          continue
+        words = line.strip().lower().split(b" ")
+        ids = [self._word2id.get(w.strip()) for w in words]
+        if None in ids or len(ids) != 4:
+          questions_skipped += 1
+        else:
+          questions.append(np.array(ids))
+    print("Eval analogy file: ", self._options.eval_data)
+    print("Questions: ", len(questions))
+    print("Skipped: ", questions_skipped)
+    self._analogy_questions = np.array(questions, dtype=np.int32)
+
+  def forward(self, examples, labels):
+    """Build the graph for the forward pass."""
+    opts = self._options
+
+    # Declare all variables we need.
+    # Embedding: [vocab_size, emb_dim]
+    init_width = 0.5 / opts.emb_dim
+    emb = tf.Variable(
+        tf.random_uniform(
+            [opts.vocab_size, opts.emb_dim], -init_width, init_width),
+        name="emb")
+    self._emb = emb
+
+    # Softmax weight: [vocab_size, emb_dim]. Transposed.
+    sm_w_t = tf.Variable(
+        tf.zeros([opts.vocab_size, opts.emb_dim]),
+        name="sm_w_t")
+
+    # Softmax bias: [emb_dim].
+    sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b")
+
+    # Global step: scalar, i.e., shape [].
+    self.global_step = tf.Variable(0, name="global_step")
+
+    # Nodes to compute the nce loss w/ candidate sampling.
+    labels_matrix = tf.reshape(
+        tf.cast(labels,
+                dtype=tf.int64),
+        [opts.batch_size, 1])
+
+    # Negative sampling.
+    sampled_ids, _, _ = (tf.nn.fixed_unigram_candidate_sampler(
+        true_classes=labels_matrix,
+        num_true=1,
+        num_sampled=opts.num_samples,
+        unique=True,
+        range_max=opts.vocab_size,
+        distortion=0.75,
+        unigrams=opts.vocab_counts.tolist()))
+
+    # Embeddings for examples: [batch_size, emb_dim]
+    example_emb = tf.nn.embedding_lookup(emb, examples)
+
+    # Weights for labels: [batch_size, emb_dim]
+    true_w = tf.nn.embedding_lookup(sm_w_t, labels)
+    # Biases for labels: [batch_size, 1]
+    true_b = tf.nn.embedding_lookup(sm_b, labels)
+
+    # Weights for sampled ids: [num_sampled, emb_dim]
+    sampled_w = tf.nn.embedding_lookup(sm_w_t, sampled_ids)
+    # Biases for sampled ids: [num_sampled, 1]
+    sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids)
+
+    # True logits: [batch_size, 1]
+    true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
+
+    # Sampled logits: [batch_size, num_sampled]
+    # We replicate sampled noise labels for all examples in the batch
+    # using the matmul.
+    sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples])
+    sampled_logits = tf.matmul(example_emb,
+                               sampled_w,
+                               transpose_b=True) + sampled_b_vec
+    return true_logits, sampled_logits
+
+  def nce_loss(self, true_logits, sampled_logits):
+    """Build the graph for the NCE loss."""
+
+    # cross-entropy(logits, labels)
+    opts = self._options
+    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        true_logits, tf.ones_like(true_logits))
+    sampled_xent = tf.nn.sigmoid_cross_entropy_with_logits(
+        sampled_logits, tf.zeros_like(sampled_logits))
+
+    # NCE-loss is the sum of the true and noise (sampled words)
+    # contributions, averaged over the batch.
+    nce_loss_tensor = (tf.reduce_sum(true_xent) +
+                       tf.reduce_sum(sampled_xent)) / opts.batch_size
+    return nce_loss_tensor
+
+  def optimize(self, loss):
+    """Build the graph to optimize the loss function."""
+
+    # Optimizer nodes.
+    # Linear learning rate decay.
+    opts = self._options
+    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
+    lr = opts.learning_rate * tf.maximum(
+        0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
+    self._lr = lr
+    optimizer = tf.train.GradientDescentOptimizer(lr)
+    train = optimizer.minimize(loss,
+                               global_step=self.global_step,
+                               gate_gradients=optimizer.GATE_NONE)
+    self._train = train
+
+  def build_eval_graph(self):
+    """Build the eval graph."""
+    # Eval graph
+
+    # Each analogy task is to predict the 4th word (d) given three
+    # words: a, b, c.  E.g., a=italy, b=rome, c=france, we should
+    # predict d=paris.
+
+    # The eval feeds three vectors of word ids for a, b, c, each of
+    # which is of size N, where N is the number of analogies we want to
+    # evaluate in one batch.
+    analogy_a = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_b = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_c = tf.placeholder(dtype=tf.int32)  # [N]
+
+    # Normalized word embeddings of shape [vocab_size, emb_dim].
+    nemb = tf.nn.l2_normalize(self._emb, 1)
+
+    # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
+    # They all have the shape [N, emb_dim]
+    a_emb = tf.gather(nemb, analogy_a)  # a's embs
+    b_emb = tf.gather(nemb, analogy_b)  # b's embs
+    c_emb = tf.gather(nemb, analogy_c)  # c's embs
+
+    # We expect that d's embedding vectors on the unit hyper-sphere is
+    # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
+    target = c_emb + (b_emb - a_emb)
+
+    # Compute cosine distance between each pair of target and vocab.
+    # dist has shape [N, vocab_size].
+    dist = tf.matmul(target, nemb, transpose_b=True)
+
+    # For each question (row in dist), find the top 4 words.
+    _, pred_idx = tf.nn.top_k(dist, 4)
+
+    # Nodes for computing neighbors for a given word according to
+    # their cosine distance.
+    nearby_word = tf.placeholder(dtype=tf.int32)  # word id
+    nearby_emb = tf.gather(nemb, nearby_word)
+    nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
+    nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
+                                         min(1000, self._options.vocab_size))
+
+    # Nodes in the construct graph which are used by training and
+    # evaluation to run/feed/fetch.
+    self._analogy_a = analogy_a
+    self._analogy_b = analogy_b
+    self._analogy_c = analogy_c
+    self._analogy_pred_idx = pred_idx
+    self._nearby_word = nearby_word
+    self._nearby_val = nearby_val
+    self._nearby_idx = nearby_idx
+
+  def build_graph(self):
+    """Build the graph for the full model."""
+    opts = self._options
+    # The training data. A text file.
+    (words, counts, words_per_epoch, self._epoch, self._words, examples,
+     labels) = word2vec.skipgram_word2vec(filename=opts.train_data,
+                                          batch_size=opts.batch_size,
+                                          window_size=opts.window_size,
+                                          min_count=opts.min_count,
+                                          subsample=opts.subsample)
+    (opts.vocab_words, opts.vocab_counts,
+     opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
+    opts.vocab_size = len(opts.vocab_words)
+    print("Data file: ", opts.train_data)
+    print("Vocab size: ", opts.vocab_size - 1, " + UNK")
+    print("Words per epoch: ", opts.words_per_epoch)
+    self._examples = examples
+    self._labels = labels
+    self._id2word = opts.vocab_words
+    for i, w in enumerate(self._id2word):
+      self._word2id[w] = i
+    true_logits, sampled_logits = self.forward(examples, labels)
+    loss = self.nce_loss(true_logits, sampled_logits)
+    tf.contrib.deprecated.scalar_summary("NCE loss", loss)
+    self._loss = loss
+    self.optimize(loss)
+
+    # Properly initialize all variables.
+    tf.global_variables_initializer().run()
+
+    self.saver = tf.train.Saver()
+
+  def save_vocab(self):
+    """Save the vocabulary to a file so the model can be reloaded."""
+    opts = self._options
+    with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
+      for i in xrange(opts.vocab_size):
+        vocab_word = tf.compat.as_text(opts.vocab_words[i]).encode("utf-8")
+        f.write("%s %d\n" % (vocab_word,
+                             opts.vocab_counts[i]))
+
+  def _train_thread_body(self):
+    initial_epoch, = self._session.run([self._epoch])
+    while True:
+      _, epoch = self._session.run([self._train, self._epoch])
+      if epoch != initial_epoch:
+        break
+
+  def train(self):
+    """Train the model."""
+    opts = self._options
+
+    initial_epoch, initial_words = self._session.run([self._epoch, self._words])
+
+    summary_op = tf.summary.merge_all()
+    summary_writer = tf.summary.FileWriter(opts.save_path, self._session.graph)
+    workers = []
+    for _ in xrange(opts.concurrent_steps):
+      t = threading.Thread(target=self._train_thread_body)
+      t.start()
+      workers.append(t)
+
+    last_words, last_time, last_summary_time = initial_words, time.time(), 0
+    last_checkpoint_time = 0
+    while True:
+      time.sleep(opts.statistics_interval)  # Reports our progress once a while.
+      (epoch, step, loss, words, lr) = self._session.run(
+          [self._epoch, self.global_step, self._loss, self._words, self._lr])
+      now = time.time()
+      last_words, last_time, rate = words, now, (words - last_words) / (
+          now - last_time)
+      print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
+            (epoch, step, lr, loss, rate), end="")
+      sys.stdout.flush()
+      if now - last_summary_time > opts.summary_interval:
+        summary_str = self._session.run(summary_op)
+        summary_writer.add_summary(summary_str, step)
+        last_summary_time = now
+      if now - last_checkpoint_time > opts.checkpoint_interval:
+        self.saver.save(self._session,
+                        os.path.join(opts.save_path, "model.ckpt"),
+                        global_step=step.astype(int))
+        last_checkpoint_time = now
+      if epoch != initial_epoch:
+        break
+
+    for t in workers:
+      t.join()
+
+    return epoch
+
+  def _predict(self, analogy):
+    """Predict the top 4 answers for analogy questions."""
+    idx, = self._session.run([self._analogy_pred_idx], {
+        self._analogy_a: analogy[:, 0],
+        self._analogy_b: analogy[:, 1],
+        self._analogy_c: analogy[:, 2]
+    })
+    return idx
+
+  def eval(self):
+    """Evaluate analogy questions and reports accuracy."""
+
+    # How many questions we get right at precision@1.
+    correct = 0
+
+    try:
+      total = self._analogy_questions.shape[0]
+    except AttributeError as e:
+      raise AttributeError("Need to read analogy questions.")
+
+    start = 0
+    while start < total:
+      limit = start + 2500
+      sub = self._analogy_questions[start:limit, :]
+      idx = self._predict(sub)
+      start = limit
+      for question in xrange(sub.shape[0]):
+        for j in xrange(4):
+          if idx[question, j] == sub[question, 3]:
+            # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
+            correct += 1
+            break
+          elif idx[question, j] in sub[question, :3]:
+            # We need to skip words already in the question.
+            continue
+          else:
+            # The correct label is not the precision@1
+            break
+    print()
+    print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+                                              correct * 100.0 / total))
+
+  def analogy(self, w0, w1, w2):
+    """Predict word w3 as in w0:w1 vs w2:w3."""
+    wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
+    idx = self._predict(wid)
+    for c in [self._id2word[i] for i in idx[0, :]]:
+      if c not in [w0, w1, w2]:
+        print(c)
+        break
+    print("unknown")
+
+  def nearby(self, words, num=20):
+    """Prints out nearby words given a list of words."""
+    ids = np.array([self._word2id.get(x, 0) for x in words])
+    vals, idx = self._session.run(
+        [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
+    for i in xrange(len(words)):
+      print("\n%s\n=====================================" % (words[i]))
+      for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
+        print("%-20s %6.4f" % (self._id2word[neighbor], distance))
+
+
+def _start_shell(local_ns=None):
+  # An interactive shell is useful for debugging/development.
+  import IPython
+  user_ns = {}
+  if local_ns:
+    user_ns.update(local_ns)
+  user_ns.update(globals())
+  IPython.start_ipython(argv=[], user_ns=user_ns)
+
+
+def main(_):
+  """Train a word2vec model."""
+  if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
+    print("--train_data --eval_data and --save_path must be specified.")
+    sys.exit(1)
+  opts = Options()
+  with tf.Graph().as_default(), tf.Session() as session:
+    with tf.device("/cpu:0"):
+      model = Word2Vec(opts, session)
+      model.read_analogies() # Read analogy questions
+    for _ in xrange(opts.epochs_to_train):
+      model.train()  # Process one epoch
+      model.eval()  # Eval analogies.
+    # Perform a final save.
+    model.saver.save(session,
+                     os.path.join(opts.save_path, "model.ckpt"),
+                     global_step=model.global_step)
+    if FLAGS.interactive:
+      # E.g.,
+      # [0]: model.analogy(b'france', b'paris', b'russia')
+      # [1]: model.nearby([b'proton', b'elephant', b'maxwell'])
+      _start_shell(locals())
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 355 - 0
tutorials/embedding/word2vec_kernels.cc

@@ -0,0 +1,355 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+// Number of examples to precalculate.
+const int kPrecalc = 3000;
+// Number of words to read into a sentence before processing.
+const int kSentenceSize = 1000;
+
+namespace {
+
+bool ScanWord(StringPiece* input, string* word) {
+  str_util::RemoveLeadingWhitespace(input);
+  StringPiece tmp;
+  if (str_util::ConsumeNonWhitespace(input, &tmp)) {
+    word->assign(tmp.data(), tmp.size());
+    return true;
+  } else {
+    return false;
+  }
+}
+
+}  // end namespace
+
+class SkipgramWord2vecOp : public OpKernel {
+ public:
+  explicit SkipgramWord2vecOp(OpKernelConstruction* ctx)
+      : OpKernel(ctx), rng_(&philox_) {
+    string filename;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
+    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));
+
+    mutex_lock l(mu_);
+    example_pos_ = corpus_size_;
+    label_pos_ = corpus_size_;
+    label_limit_ = corpus_size_;
+    sentence_index_ = kSentenceSize;
+    for (int i = 0; i < kPrecalc; ++i) {
+      NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label);
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor words_per_epoch(DT_INT64, TensorShape({}));
+    Tensor current_epoch(DT_INT32, TensorShape({}));
+    Tensor total_words_processed(DT_INT64, TensorShape({}));
+    Tensor examples(DT_INT32, TensorShape({batch_size_}));
+    auto Texamples = examples.flat<int32>();
+    Tensor labels(DT_INT32, TensorShape({batch_size_}));
+    auto Tlabels = labels.flat<int32>();
+    {
+      mutex_lock l(mu_);
+      for (int i = 0; i < batch_size_; ++i) {
+        Texamples(i) = precalc_examples_[precalc_index_].input;
+        Tlabels(i) = precalc_examples_[precalc_index_].label;
+        precalc_index_++;
+        if (precalc_index_ >= kPrecalc) {
+          precalc_index_ = 0;
+          for (int j = 0; j < kPrecalc; ++j) {
+            NextExample(&precalc_examples_[j].input,
+                        &precalc_examples_[j].label);
+          }
+        }
+      }
+      words_per_epoch.scalar<int64>()() = corpus_size_;
+      current_epoch.scalar<int32>()() = current_epoch_;
+      total_words_processed.scalar<int64>()() = total_words_processed_;
+    }
+    ctx->set_output(0, word_);
+    ctx->set_output(1, freq_);
+    ctx->set_output(2, words_per_epoch);
+    ctx->set_output(3, current_epoch);
+    ctx->set_output(4, total_words_processed);
+    ctx->set_output(5, examples);
+    ctx->set_output(6, labels);
+  }
+
+ private:
+  struct Example {
+    int32 input;
+    int32 label;
+  };
+
+  int32 batch_size_ = 0;
+  int32 window_size_ = 5;
+  float subsample_ = 1e-3;
+  int min_count_ = 5;
+  int32 vocab_size_ = 0;
+  Tensor word_;
+  Tensor freq_;
+  int64 corpus_size_ = 0;
+  std::vector<int32> corpus_;
+  std::vector<Example> precalc_examples_;
+  int precalc_index_ = 0;
+  std::vector<int32> sentence_;
+  int sentence_index_ = 0;
+
+  mutex mu_;
+  random::PhiloxRandom philox_ GUARDED_BY(mu_);
+  random::SimplePhilox rng_ GUARDED_BY(mu_);
+  int32 current_epoch_ GUARDED_BY(mu_) = -1;
+  int64 total_words_processed_ GUARDED_BY(mu_) = 0;
+  int32 example_pos_ GUARDED_BY(mu_);
+  int32 label_pos_ GUARDED_BY(mu_);
+  int32 label_limit_ GUARDED_BY(mu_);
+
+  // {example_pos_, label_pos_} is the cursor for the next example.
+  // example_pos_ wraps around at the end of corpus_. For each
+  // example, we randomly generate [label_pos_, label_limit) for
+  // labels.
+  void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    while (true) {
+      if (label_pos_ >= label_limit_) {
+        ++total_words_processed_;
+        ++sentence_index_;
+        if (sentence_index_ >= kSentenceSize) {
+          sentence_index_ = 0;
+          for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) {
+            if (example_pos_ >= corpus_size_) {
+              ++current_epoch_;
+              example_pos_ = 0;
+            }
+            if (subsample_ > 0) {
+              int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
+              // See Eq. 5 in http://arxiv.org/abs/1310.4546
+              float keep_prob =
+                  (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
+                  (subsample_ * corpus_size_) / word_freq;
+              if (rng_.RandFloat() > keep_prob) {
+                i--;
+                continue;
+              }
+            }
+            sentence_[i] = corpus_[example_pos_];
+          }
+        }
+        const int32 skip = 1 + rng_.Uniform(window_size_);
+        label_pos_ = std::max<int32>(0, sentence_index_ - skip);
+        label_limit_ =
+            std::min<int32>(kSentenceSize, sentence_index_ + skip + 1);
+      }
+      if (sentence_index_ != label_pos_) {
+        break;
+      }
+      ++label_pos_;
+    }
+    *example = sentence_[sentence_index_];
+    *label = sentence_[label_pos_++];
+  }
+
+  Status Init(Env* env, const string& filename) {
+    string data;
+    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
+    StringPiece input = data;
+    string w;
+    corpus_size_ = 0;
+    std::unordered_map<string, int32> word_freq;
+    while (ScanWord(&input, &w)) {
+      ++(word_freq[w]);
+      ++corpus_size_;
+    }
+    if (corpus_size_ < window_size_ * 10) {
+      return errors::InvalidArgument("The text file ", filename,
+                                     " contains too little data: ",
+                                     corpus_size_, " words");
+    }
+    typedef std::pair<string, int32> WordFreq;
+    std::vector<WordFreq> ordered;
+    for (const auto& p : word_freq) {
+      if (p.second >= min_count_) ordered.push_back(p);
+    }
+    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
+              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
+              << " unique words, " << ordered.size()
+              << " unique frequent words.";
+    word_freq.clear();
+    std::sort(ordered.begin(), ordered.end(),
+              [](const WordFreq& x, const WordFreq& y) {
+                return x.second > y.second;
+              });
+    vocab_size_ = static_cast<int32>(1 + ordered.size());
+    Tensor word(DT_STRING, TensorShape({vocab_size_}));
+    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
+    word.flat<string>()(0) = "UNK";
+    static const int32 kUnkId = 0;
+    std::unordered_map<string, int32> word_id;
+    int64 total_counted = 0;
+    for (std::size_t i = 0; i < ordered.size(); ++i) {
+      const auto& w = ordered[i].first;
+      auto id = i + 1;
+      word.flat<string>()(id) = w;
+      auto word_count = ordered[i].second;
+      freq.flat<int32>()(id) = word_count;
+      total_counted += word_count;
+      word_id[w] = id;
+    }
+    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
+    word_ = word;
+    freq_ = freq;
+    corpus_.reserve(corpus_size_);
+    input = data;
+    while (ScanWord(&input, &w)) {
+      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
+    }
+    precalc_examples_.resize(kPrecalc);
+    sentence_.resize(kSentenceSize);
+    return Status::OK();
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("SkipgramWord2vec").Device(DEVICE_CPU), SkipgramWord2vecOp);
+
+class NegTrainWord2vecOp : public OpKernel {
+ public:
+  explicit NegTrainWord2vecOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    base_.Init(0, 0);
+
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));
+
+    std::vector<int32> vocab_count;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));
+
+    std::vector<float> vocab_weights;
+    vocab_weights.reserve(vocab_count.size());
+    for (const auto& f : vocab_count) {
+      float r = std::pow(static_cast<float>(f), 0.75f);
+      vocab_weights.push_back(r);
+    }
+    sampler_ = new random::DistributionSampler(vocab_weights);
+  }
+
+  ~NegTrainWord2vecOp() { delete sampler_; }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor w_in = ctx->mutable_input(0, false);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
+                errors::InvalidArgument("Must be a matrix"));
+    Tensor w_out = ctx->mutable_input(1, false);
+    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
+                errors::InvalidArgument("w_in.shape == w_out.shape"));
+    const Tensor& examples = ctx->input(2);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
+                errors::InvalidArgument("Must be a vector"));
+    const Tensor& labels = ctx->input(3);
+    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
+                errors::InvalidArgument("examples.shape == labels.shape"));
+    const Tensor& learning_rate = ctx->input(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
+                errors::InvalidArgument("Must be a scalar"));
+
+    auto Tw_in = w_in.matrix<float>();
+    auto Tw_out = w_out.matrix<float>();
+    auto Texamples = examples.flat<int32>();
+    auto Tlabels = labels.flat<int32>();
+    auto lr = learning_rate.scalar<float>()();
+    const int64 vocab_size = w_in.dim_size(0);
+    const int64 dims = w_in.dim_size(1);
+    const int64 batch_size = examples.dim_size(0);
+    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
+                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
+                                        " vs. ", sampler_->num()));
+
+    // Gradient accumulator for v_in.
+    Tensor buf(DT_FLOAT, TensorShape({dims}));
+    auto Tbuf = buf.flat<float>();
+
+    // Scalar buffer to hold sigmoid(+/- dot).
+    Tensor g_buf(DT_FLOAT, TensorShape({}));
+    auto g = g_buf.scalar<float>();
+
+    // The following loop needs 2 random 32-bit values per negative
+    // sample.  We reserve 8 values per sample just in case the
+    // underlying implementation changes.
+    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
+    random::SimplePhilox srnd(&rnd);
+
+    for (int64 i = 0; i < batch_size; ++i) {
+      const int32 example = Texamples(i);
+      DCHECK(0 <= example && example < vocab_size) << example;
+      const int32 label = Tlabels(i);
+      DCHECK(0 <= label && label < vocab_size) << label;
+      auto v_in = Tw_in.chip<0>(example);
+
+      // Positive: example predicts label.
+      //   forward: x = v_in' * v_out
+      //            l = log(sigmoid(x))
+      //   backward: dl/dx = g = sigmoid(-x)
+      //             dl/d(v_in) = g * v_out'
+      //             dl/d(v_out) = v_in' * g
+      {
+        auto v_out = Tw_out.chip<0>(label);
+        auto dot = (v_in * v_out).sum();
+        g = (dot.exp() + 1.f).inverse();
+        Tbuf = v_out * (g() * lr);
+        v_out += v_in * (g() * lr);
+      }
+
+      // Negative samples:
+      //   forward: x = v_in' * v_sample
+      //            l = log(sigmoid(-x))
+      //   backward: dl/dx = g = -sigmoid(x)
+      //             dl/d(v_in) = g * v_out'
+      //             dl/d(v_out) = v_in' * g
+      for (int j = 0; j < num_samples_; ++j) {
+        const int sample = sampler_->Sample(&srnd);
+        if (sample == label) continue;  // Skip.
+        auto v_sample = Tw_out.chip<0>(sample);
+        auto dot = (v_in * v_sample).sum();
+        g = -((-dot).exp() + 1.f).inverse();
+        Tbuf += v_sample * (g() * lr);
+        v_sample += v_in * (g() * lr);
+      }
+
+      // Applies the gradient on v_in.
+      v_in += Tbuf;
+    }
+  }
+
+ private:
+  int32 num_samples_ = 0;
+  random::DistributionSampler* sampler_ = nullptr;
+  GuardedPhiloxRandom base_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("NegTrainWord2vec").Device(DEVICE_CPU), NegTrainWord2vecOp);
+
+}  // end namespace tensorflow

+ 73 - 0
tutorials/embedding/word2vec_ops.cc

@@ -0,0 +1,73 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("SkipgramWord2vec")
+    .Output("vocab_word: string")
+    .Output("vocab_freq: int32")
+    .Output("words_per_epoch: int64")
+    .Output("current_epoch: int32")
+    .Output("total_words_processed: int64")
+    .Output("examples: int32")
+    .Output("labels: int32")
+    .SetIsStateful()
+    .Attr("filename: string")
+    .Attr("batch_size: int")
+    .Attr("window_size: int = 5")
+    .Attr("min_count: int = 5")
+    .Attr("subsample: float = 1e-3")
+    .Doc(R"doc(
+Parses a text file and creates a batch of examples.
+
+vocab_word: A vector of words in the corpus.
+vocab_freq: Frequencies of words. Sorted in the non-ascending order.
+words_per_epoch: Number of words per epoch in the data file.
+current_epoch: The current epoch number.
+total_words_processed: The total number of words processed so far.
+examples: A vector of word ids.
+labels: A vector of word ids.
+filename: The corpus's text file name.
+batch_size: The size of produced batch.
+window_size: The number of words to predict to the left and right of the target.
+min_count: The minimum number of word occurrences for it to be included in the
+    vocabulary.
+subsample: Threshold for word occurrence. Words that appear with higher
+    frequency will be randomly down-sampled. Set to 0 to disable.
+)doc");
+
+REGISTER_OP("NegTrainWord2vec")
+    .Input("w_in: Ref(float)")
+    .Input("w_out: Ref(float)")
+    .Input("examples: int32")
+    .Input("labels: int32")
+    .Input("lr: float")
+    .SetIsStateful()
+    .Attr("vocab_count: list(int)")
+    .Attr("num_negative_samples: int")
+    .Doc(R"doc(
+Training via negative sampling.
+
+w_in: input word embedding.
+w_out: output word embedding.
+examples: A vector of word ids.
+labels: A vector of word ids.
+vocab_count: Count of words in the vocabulary.
+num_negative_samples: Number of negative samples per example.
+)doc");
+
+}  // end namespace tensorflow

+ 439 - 0
tutorials/embedding/word2vec_optimized.py

@@ -0,0 +1,439 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Multi-threaded word2vec unbatched skip-gram model.
+
+Trains the model described in:
+(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space
+ICLR 2013.
+http://arxiv.org/abs/1301.3781
+This model does true SGD (i.e. no minibatching). To do this efficiently, custom
+ops are used to sequentially process data within a 'batch'.
+
+The key ops used are:
+* skipgram custom op that does input processing.
+* neg_train custom op that efficiently calculates and applies the gradient using
+  true SGD.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import threading
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.embedding import gen_word2vec as word2vec
+
+flags = tf.app.flags
+
+flags.DEFINE_string("save_path", None, "Directory to write the model.")
+flags.DEFINE_string(
+    "train_data", None,
+    "Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
+flags.DEFINE_string(
+    "eval_data", None, "Analogy questions. "
+    "See README.md for how to get 'questions-words.txt'.")
+flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
+flags.DEFINE_integer(
+    "epochs_to_train", 15,
+    "Number of epochs to train. Each epoch processes the training data once "
+    "completely.")
+flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
+flags.DEFINE_integer("num_neg_samples", 25,
+                     "Negative samples per training example.")
+flags.DEFINE_integer("batch_size", 500,
+                     "Numbers of training examples each step processes "
+                     "(no minibatching).")
+flags.DEFINE_integer("concurrent_steps", 12,
+                     "The number of concurrent training steps.")
+flags.DEFINE_integer("window_size", 5,
+                     "The number of words to predict to the left and right "
+                     "of the target word.")
+flags.DEFINE_integer("min_count", 5,
+                     "The minimum number of word occurrences for it to be "
+                     "included in the vocabulary.")
+flags.DEFINE_float("subsample", 1e-3,
+                   "Subsample threshold for word occurrence. Words that appear "
+                   "with higher frequency will be randomly down-sampled. Set "
+                   "to 0 to disable.")
+flags.DEFINE_boolean(
+    "interactive", False,
+    "If true, enters an IPython interactive session to play with the trained "
+    "model. E.g., try model.analogy(b'france', b'paris', b'russia') and "
+    "model.nearby([b'proton', b'elephant', b'maxwell'])")
+
+FLAGS = flags.FLAGS
+
+
+class Options(object):
+  """Options used by our word2vec model."""
+
+  def __init__(self):
+    # Model options.
+
+    # Embedding dimension.
+    self.emb_dim = FLAGS.embedding_size
+
+    # Training options.
+
+    # The training text file.
+    self.train_data = FLAGS.train_data
+
+    # Number of negative samples per example.
+    self.num_samples = FLAGS.num_neg_samples
+
+    # The initial learning rate.
+    self.learning_rate = FLAGS.learning_rate
+
+    # Number of epochs to train. After these many epochs, the learning
+    # rate decays linearly to zero and the training stops.
+    self.epochs_to_train = FLAGS.epochs_to_train
+
+    # Concurrent training steps.
+    self.concurrent_steps = FLAGS.concurrent_steps
+
+    # Number of examples for one training step.
+    self.batch_size = FLAGS.batch_size
+
+    # The number of words to predict to the left and right of the target word.
+    self.window_size = FLAGS.window_size
+
+    # The minimum number of word occurrences for it to be included in the
+    # vocabulary.
+    self.min_count = FLAGS.min_count
+
+    # Subsampling threshold for word occurrence.
+    self.subsample = FLAGS.subsample
+
+    # Where to write out summaries.
+    self.save_path = FLAGS.save_path
+    if not os.path.exists(self.save_path):
+      os.makedirs(self.save_path)
+
+    # Eval options.
+
+    # The text file for eval.
+    self.eval_data = FLAGS.eval_data
+
+
+class Word2Vec(object):
+  """Word2Vec model (Skipgram)."""
+
+  def __init__(self, options, session):
+    self._options = options
+    self._session = session
+    self._word2id = {}
+    self._id2word = []
+    self.build_graph()
+    self.build_eval_graph()
+    self.save_vocab()
+
+  def read_analogies(self):
+    """Reads through the analogy question file.
+
+    Returns:
+      questions: a [n, 4] numpy array containing the analogy question's
+                 word ids.
+      questions_skipped: questions skipped due to unknown words.
+    """
+    questions = []
+    questions_skipped = 0
+    with open(self._options.eval_data, "rb") as analogy_f:
+      for line in analogy_f:
+        if line.startswith(b":"):  # Skip comments.
+          continue
+        words = line.strip().lower().split(b" ")
+        ids = [self._word2id.get(w.strip()) for w in words]
+        if None in ids or len(ids) != 4:
+          questions_skipped += 1
+        else:
+          questions.append(np.array(ids))
+    print("Eval analogy file: ", self._options.eval_data)
+    print("Questions: ", len(questions))
+    print("Skipped: ", questions_skipped)
+    self._analogy_questions = np.array(questions, dtype=np.int32)
+
+  def build_graph(self):
+    """Build the model graph."""
+    opts = self._options
+
+    # The training data. A text file.
+    (words, counts, words_per_epoch, current_epoch, total_words_processed,
+     examples, labels) = word2vec.skipgram_word2vec(filename=opts.train_data,
+                                                    batch_size=opts.batch_size,
+                                                    window_size=opts.window_size,
+                                                    min_count=opts.min_count,
+                                                    subsample=opts.subsample)
+    (opts.vocab_words, opts.vocab_counts,
+     opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
+    opts.vocab_size = len(opts.vocab_words)
+    print("Data file: ", opts.train_data)
+    print("Vocab size: ", opts.vocab_size - 1, " + UNK")
+    print("Words per epoch: ", opts.words_per_epoch)
+
+    self._id2word = opts.vocab_words
+    for i, w in enumerate(self._id2word):
+      self._word2id[w] = i
+
+    # Declare all variables we need.
+    # Input words embedding: [vocab_size, emb_dim]
+    w_in = tf.Variable(
+        tf.random_uniform(
+            [opts.vocab_size,
+             opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim),
+        name="w_in")
+
+    # Global step: scalar, i.e., shape [].
+    w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out")
+
+    # Global step: []
+    global_step = tf.Variable(0, name="global_step")
+
+    # Linear learning rate decay.
+    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
+    lr = opts.learning_rate * tf.maximum(
+        0.0001,
+        1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train)
+
+    # Training nodes.
+    inc = global_step.assign_add(1)
+    with tf.control_dependencies([inc]):
+      train = word2vec.neg_train_word2vec(w_in,
+                                          w_out,
+                                          examples,
+                                          labels,
+                                          lr,
+                                          vocab_count=opts.vocab_counts.tolist(),
+                                          num_negative_samples=opts.num_samples)
+
+    self._w_in = w_in
+    self._examples = examples
+    self._labels = labels
+    self._lr = lr
+    self._train = train
+    self.global_step = global_step
+    self._epoch = current_epoch
+    self._words = total_words_processed
+
+  def save_vocab(self):
+    """Save the vocabulary to a file so the model can be reloaded."""
+    opts = self._options
+    with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
+      for i in xrange(opts.vocab_size):
+        vocab_word = tf.compat.as_text(opts.vocab_words[i]).encode("utf-8")
+        f.write("%s %d\n" % (vocab_word,
+                             opts.vocab_counts[i]))
+
+  def build_eval_graph(self):
+    """Build the evaluation graph."""
+    # Eval graph
+    opts = self._options
+
+    # Each analogy task is to predict the 4th word (d) given three
+    # words: a, b, c.  E.g., a=italy, b=rome, c=france, we should
+    # predict d=paris.
+
+    # The eval feeds three vectors of word ids for a, b, c, each of
+    # which is of size N, where N is the number of analogies we want to
+    # evaluate in one batch.
+    analogy_a = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_b = tf.placeholder(dtype=tf.int32)  # [N]
+    analogy_c = tf.placeholder(dtype=tf.int32)  # [N]
+
+    # Normalized word embeddings of shape [vocab_size, emb_dim].
+    nemb = tf.nn.l2_normalize(self._w_in, 1)
+
+    # Each row of a_emb, b_emb, c_emb is a word's embedding vector.
+    # They all have the shape [N, emb_dim]
+    a_emb = tf.gather(nemb, analogy_a)  # a's embs
+    b_emb = tf.gather(nemb, analogy_b)  # b's embs
+    c_emb = tf.gather(nemb, analogy_c)  # c's embs
+
+    # We expect that d's embedding vectors on the unit hyper-sphere is
+    # near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
+    target = c_emb + (b_emb - a_emb)
+
+    # Compute cosine distance between each pair of target and vocab.
+    # dist has shape [N, vocab_size].
+    dist = tf.matmul(target, nemb, transpose_b=True)
+
+    # For each question (row in dist), find the top 4 words.
+    _, pred_idx = tf.nn.top_k(dist, 4)
+
+    # Nodes for computing neighbors for a given word according to
+    # their cosine distance.
+    nearby_word = tf.placeholder(dtype=tf.int32)  # word id
+    nearby_emb = tf.gather(nemb, nearby_word)
+    nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
+    nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
+                                         min(1000, opts.vocab_size))
+
+    # Nodes in the construct graph which are used by training and
+    # evaluation to run/feed/fetch.
+    self._analogy_a = analogy_a
+    self._analogy_b = analogy_b
+    self._analogy_c = analogy_c
+    self._analogy_pred_idx = pred_idx
+    self._nearby_word = nearby_word
+    self._nearby_val = nearby_val
+    self._nearby_idx = nearby_idx
+
+    # Properly initialize all variables.
+    tf.global_variables_initializer().run()
+
+    self.saver = tf.train.Saver()
+
+  def _train_thread_body(self):
+    initial_epoch, = self._session.run([self._epoch])
+    while True:
+      _, epoch = self._session.run([self._train, self._epoch])
+      if epoch != initial_epoch:
+        break
+
+  def train(self):
+    """Train the model."""
+    opts = self._options
+
+    initial_epoch, initial_words = self._session.run([self._epoch, self._words])
+
+    workers = []
+    for _ in xrange(opts.concurrent_steps):
+      t = threading.Thread(target=self._train_thread_body)
+      t.start()
+      workers.append(t)
+
+    last_words, last_time = initial_words, time.time()
+    while True:
+      time.sleep(5)  # Reports our progress once a while.
+      (epoch, step, words, lr) = self._session.run(
+          [self._epoch, self.global_step, self._words, self._lr])
+      now = time.time()
+      last_words, last_time, rate = words, now, (words - last_words) / (
+          now - last_time)
+      print("Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
+                                                                    lr, rate),
+            end="")
+      sys.stdout.flush()
+      if epoch != initial_epoch:
+        break
+
+    for t in workers:
+      t.join()
+
+  def _predict(self, analogy):
+    """Predict the top 4 answers for analogy questions."""
+    idx, = self._session.run([self._analogy_pred_idx], {
+        self._analogy_a: analogy[:, 0],
+        self._analogy_b: analogy[:, 1],
+        self._analogy_c: analogy[:, 2]
+    })
+    return idx
+
+  def eval(self):
+    """Evaluate analogy questions and reports accuracy."""
+
+    # How many questions we get right at precision@1.
+    correct = 0
+
+    try:
+      total = self._analogy_questions.shape[0]
+    except AttributeError as e:
+      raise AttributeError("Need to read analogy questions.")
+
+    start = 0
+    while start < total:
+      limit = start + 2500
+      sub = self._analogy_questions[start:limit, :]
+      idx = self._predict(sub)
+      start = limit
+      for question in xrange(sub.shape[0]):
+        for j in xrange(4):
+          if idx[question, j] == sub[question, 3]:
+            # Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
+            correct += 1
+            break
+          elif idx[question, j] in sub[question, :3]:
+            # We need to skip words already in the question.
+            continue
+          else:
+            # The correct label is not the precision@1
+            break
+    print()
+    print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+                                              correct * 100.0 / total))
+
+  def analogy(self, w0, w1, w2):
+    """Predict word w3 as in w0:w1 vs w2:w3."""
+    wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
+    idx = self._predict(wid)
+    for c in [self._id2word[i] for i in idx[0, :]]:
+      if c not in [w0, w1, w2]:
+        print(c)
+        break
+    print("unknown")
+
+  def nearby(self, words, num=20):
+    """Prints out nearby words given a list of words."""
+    ids = np.array([self._word2id.get(x, 0) for x in words])
+    vals, idx = self._session.run(
+        [self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
+    for i in xrange(len(words)):
+      print("\n%s\n=====================================" % (words[i]))
+      for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
+        print("%-20s %6.4f" % (self._id2word[neighbor], distance))
+
+
+def _start_shell(local_ns=None):
+  # An interactive shell is useful for debugging/development.
+  import IPython
+  user_ns = {}
+  if local_ns:
+    user_ns.update(local_ns)
+  user_ns.update(globals())
+  IPython.start_ipython(argv=[], user_ns=user_ns)
+
+
+def main(_):
+  """Train a word2vec model."""
+  if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
+    print("--train_data --eval_data and --save_path must be specified.")
+    sys.exit(1)
+  opts = Options()
+  with tf.Graph().as_default(), tf.Session() as session:
+    with tf.device("/cpu:0"):
+      model = Word2Vec(opts, session)
+      model.read_analogies() # Read analogy questions
+    for _ in xrange(opts.epochs_to_train):
+      model.train()  # Process one epoch
+      model.eval()  # Eval analogies.
+    # Perform a final save.
+    model.saver.save(session, os.path.join(opts.save_path, "model.ckpt"),
+                     global_step=model.global_step)
+    if FLAGS.interactive:
+      # E.g.,
+      # [0]: model.analogy(b'france', b'paris', b'russia')
+      # [1]: model.nearby([b'proton', b'elephant', b'maxwell'])
+      _start_shell(locals())
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 62 - 0
tutorials/embedding/word2vec_optimized_test.py

@@ -0,0 +1,62 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for word2vec_optimized module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from tensorflow.models.embedding import word2vec_optimized
+
+flags = tf.app.flags
+
+FLAGS = flags.FLAGS
+
+
+class Word2VecTest(tf.test.TestCase):
+
+  def setUp(self):
+    FLAGS.train_data = os.path.join(self.get_temp_dir() + "test-text.txt")
+    FLAGS.eval_data = os.path.join(self.get_temp_dir() + "eval-text.txt")
+    FLAGS.save_path = self.get_temp_dir()
+    with open(FLAGS.train_data, "w") as f:
+      f.write(
+          """alice was beginning to get very tired of sitting by her sister on
+          the bank, and of having nothing to do: once or twice she had peeped
+          into the book her sister was reading, but it had no pictures or
+          conversations in it, 'and what is the use of a book,' thought alice
+          'without pictures or conversations?' So she was considering in her own
+          mind (as well as she could, for the hot day made her feel very sleepy
+          and stupid), whether the pleasure of making a daisy-chain would be
+          worth the trouble of getting up and picking the daisies, when suddenly
+          a White rabbit with pink eyes ran close by her.\n""")
+      with open(FLAGS.eval_data, "w") as f:
+        f.write("alice she rabbit once\n")
+
+  def testWord2VecOptimized(self):
+    FLAGS.batch_size = 5
+    FLAGS.num_neg_samples = 10
+    FLAGS.epochs_to_train = 1
+    FLAGS.min_count = 0
+    word2vec_optimized.main([])
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 62 - 0
tutorials/embedding/word2vec_test.py

@@ -0,0 +1,62 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for word2vec module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from tensorflow.models.embedding import word2vec
+
+flags = tf.app.flags
+
+FLAGS = flags.FLAGS
+
+
+class Word2VecTest(tf.test.TestCase):
+
+  def setUp(self):
+    FLAGS.train_data = os.path.join(self.get_temp_dir(), "test-text.txt")
+    FLAGS.eval_data = os.path.join(self.get_temp_dir(), "eval-text.txt")
+    FLAGS.save_path = self.get_temp_dir()
+    with open(FLAGS.train_data, "w") as f:
+      f.write(
+          """alice was beginning to get very tired of sitting by her sister on
+          the bank, and of having nothing to do: once or twice she had peeped
+          into the book her sister was reading, but it had no pictures or
+          conversations in it, 'and what is the use of a book,' thought alice
+          'without pictures or conversations?' So she was considering in her own
+          mind (as well as she could, for the hot day made her feel very sleepy
+          and stupid), whether the pleasure of making a daisy-chain would be
+          worth the trouble of getting up and picking the daisies, when suddenly
+          a White rabbit with pink eyes ran close by her.\n""")
+      with open(FLAGS.eval_data, "w") as f:
+        f.write("alice she rabbit once\n")
+
+  def testWord2Vec(self):
+    FLAGS.batch_size = 5
+    FLAGS.num_neg_samples = 10
+    FLAGS.epochs_to_train = 1
+    FLAGS.min_count = 0
+    word2vec.main([])
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 0 - 0
tutorials/image/__init__.py


+ 29 - 0
tutorials/image/alexnet/BUILD

@@ -0,0 +1,29 @@
+# Description:
+# Benchmark for AlexNet.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+    name = "alexnet_benchmark",
+    srcs = [
+        "alexnet_benchmark.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 0 - 0
tutorials/image/alexnet/__init__.py


+ 246 - 0
tutorials/image/alexnet/alexnet_benchmark.py

@@ -0,0 +1,246 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Timing benchmark for AlexNet inference.
+
+To run, use:
+  bazel run -c opt --config=cuda \
+      third_party/tensorflow/models/image/alexnet:alexnet_benchmark
+
+Across 100 steps on batch size = 128.
+
+Forward pass:
+Run on Tesla K40c: 145 +/- 1.5 ms / batch
+Run on Titan X:     70 +/- 0.1 ms / batch
+
+Forward-backward pass:
+Run on Tesla K40c: 480 +/- 48 ms / batch
+Run on Titan X:    244 +/- 30 ms / batch
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+from datetime import datetime
+import math
+import sys
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+FLAGS = None
+
+
+def print_activations(t):
+  print(t.op.name, ' ', t.get_shape().as_list())
+
+
+def inference(images):
+  """Build the AlexNet model.
+
+  Args:
+    images: Images Tensor
+
+  Returns:
+    pool5: the last Tensor in the convolutional component of AlexNet.
+    parameters: a list of Tensors corresponding to the weights and biases of the
+        AlexNet model.
+  """
+  parameters = []
+  # conv1
+  with tf.name_scope('conv1') as scope:
+    kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype=tf.float32,
+                                             stddev=1e-1), name='weights')
+    conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='SAME')
+    biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
+                         trainable=True, name='biases')
+    bias = tf.nn.bias_add(conv, biases)
+    conv1 = tf.nn.relu(bias, name=scope)
+    print_activations(conv1)
+    parameters += [kernel, biases]
+
+  # lrn1
+  # TODO(shlens, jiayq): Add a GPU version of local response normalization.
+
+  # pool1
+  pool1 = tf.nn.max_pool(conv1,
+                         ksize=[1, 3, 3, 1],
+                         strides=[1, 2, 2, 1],
+                         padding='VALID',
+                         name='pool1')
+  print_activations(pool1)
+
+  # conv2
+  with tf.name_scope('conv2') as scope:
+    kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype=tf.float32,
+                                             stddev=1e-1), name='weights')
+    conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = tf.Variable(tf.constant(0.0, shape=[192], dtype=tf.float32),
+                         trainable=True, name='biases')
+    bias = tf.nn.bias_add(conv, biases)
+    conv2 = tf.nn.relu(bias, name=scope)
+    parameters += [kernel, biases]
+  print_activations(conv2)
+
+  # pool2
+  pool2 = tf.nn.max_pool(conv2,
+                         ksize=[1, 3, 3, 1],
+                         strides=[1, 2, 2, 1],
+                         padding='VALID',
+                         name='pool2')
+  print_activations(pool2)
+
+  # conv3
+  with tf.name_scope('conv3') as scope:
+    kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384],
+                                             dtype=tf.float32,
+                                             stddev=1e-1), name='weights')
+    conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = tf.Variable(tf.constant(0.0, shape=[384], dtype=tf.float32),
+                         trainable=True, name='biases')
+    bias = tf.nn.bias_add(conv, biases)
+    conv3 = tf.nn.relu(bias, name=scope)
+    parameters += [kernel, biases]
+    print_activations(conv3)
+
+  # conv4
+  with tf.name_scope('conv4') as scope:
+    kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256],
+                                             dtype=tf.float32,
+                                             stddev=1e-1), name='weights')
+    conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
+                         trainable=True, name='biases')
+    bias = tf.nn.bias_add(conv, biases)
+    conv4 = tf.nn.relu(bias, name=scope)
+    parameters += [kernel, biases]
+    print_activations(conv4)
+
+  # conv5
+  with tf.name_scope('conv5') as scope:
+    kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256],
+                                             dtype=tf.float32,
+                                             stddev=1e-1), name='weights')
+    conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
+                         trainable=True, name='biases')
+    bias = tf.nn.bias_add(conv, biases)
+    conv5 = tf.nn.relu(bias, name=scope)
+    parameters += [kernel, biases]
+    print_activations(conv5)
+
+  # pool5
+  pool5 = tf.nn.max_pool(conv5,
+                         ksize=[1, 3, 3, 1],
+                         strides=[1, 2, 2, 1],
+                         padding='VALID',
+                         name='pool5')
+  print_activations(pool5)
+
+  return pool5, parameters
+
+
+def time_tensorflow_run(session, target, info_string):
+  """Run the computation to obtain the target tensor and print timing stats.
+
+  Args:
+    session: the TensorFlow session to run the computation under.
+    target: the target Tensor that is passed to the session's run() function.
+    info_string: a string summarizing this run, to be printed with the stats.
+
+  Returns:
+    None
+  """
+  num_steps_burn_in = 10
+  total_duration = 0.0
+  total_duration_squared = 0.0
+  for i in xrange(FLAGS.num_batches + num_steps_burn_in):
+    start_time = time.time()
+    _ = session.run(target)
+    duration = time.time() - start_time
+    if i >= num_steps_burn_in:
+      if not i % 10:
+        print ('%s: step %d, duration = %.3f' %
+               (datetime.now(), i - num_steps_burn_in, duration))
+      total_duration += duration
+      total_duration_squared += duration * duration
+  mn = total_duration / FLAGS.num_batches
+  vr = total_duration_squared / FLAGS.num_batches - mn * mn
+  sd = math.sqrt(vr)
+  print ('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %
+         (datetime.now(), info_string, FLAGS.num_batches, mn, sd))
+
+
+
+def run_benchmark():
+  """Run the benchmark on AlexNet."""
+  with tf.Graph().as_default():
+    # Generate some dummy images.
+    image_size = 224
+    # Note that our padding definition is slightly different the cuda-convnet.
+    # In order to force the model to start with the same activations sizes,
+    # we add 3 to the image_size and employ VALID padding above.
+    images = tf.Variable(tf.random_normal([FLAGS.batch_size,
+                                           image_size,
+                                           image_size, 3],
+                                          dtype=tf.float32,
+                                          stddev=1e-1))
+
+    # Build a Graph that computes the logits predictions from the
+    # inference model.
+    pool5, parameters = inference(images)
+
+    # Build an initialization operation.
+    init = tf.global_variables_initializer()
+
+    # Start running operations on the Graph.
+    config = tf.ConfigProto()
+    config.gpu_options.allocator_type = 'BFC'
+    sess = tf.Session(config=config)
+    sess.run(init)
+
+    # Run the forward benchmark.
+    time_tensorflow_run(sess, pool5, "Forward")
+
+    # Add a simple objective so we can calculate the backward pass.
+    objective = tf.nn.l2_loss(pool5)
+    # Compute the gradient with respect to all the parameters.
+    grad = tf.gradients(objective, parameters)
+    # Run the backward benchmark.
+    time_tensorflow_run(sess, grad, "Forward-backward")
+
+
+def main(_):
+  run_benchmark()
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--batch_size',
+      type=int,
+      default=128,
+      help='Batch size.'
+  )
+  parser.add_argument(
+      '--num_batches',
+      type=int,
+      default=100,
+      help='Number of batches to run.'
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

+ 87 - 0
tutorials/image/cifar10/BUILD

@@ -0,0 +1,87 @@
+# Description:
+# Example TensorFlow models for CIFAR-10
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "cifar10_input",
+    srcs = ["cifar10_input.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "cifar10_input_test",
+    size = "small",
+    srcs = ["cifar10_input_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":cifar10_input",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+py_library(
+    name = "cifar10",
+    srcs = ["cifar10.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":cifar10_input",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
+    name = "cifar10_eval",
+    srcs = [
+        "cifar10_eval.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":cifar10",
+    ],
+)
+
+py_binary(
+    name = "cifar10_train",
+    srcs = [
+        "cifar10_train.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":cifar10",
+    ],
+)
+
+py_binary(
+    name = "cifar10_multi_gpu_train",
+    srcs = [
+        "cifar10_multi_gpu_train.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":cifar10",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 10 - 0
tutorials/image/cifar10/README.md

@@ -0,0 +1,10 @@
+CIFAR-10 is a common benchmark in machine learning for image recognition.
+
+http://www.cs.toronto.edu/~kriz/cifar.html
+
+Code in this directory demonstrates how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on both CPU and GPU. We also demonstrate how to train a CNN over multiple GPUs.
+
+Detailed instructions on how to get started available at:
+
+http://tensorflow.org/tutorials/deep_cnn/
+

+ 22 - 0
tutorials/image/cifar10/__init__.py

@@ -0,0 +1,22 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Makes helper libraries available in the cifar10 package."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.models.image.cifar10 import cifar10
+from tensorflow.models.image.cifar10 import cifar10_input

+ 399 - 0
tutorials/image/cifar10/cifar10.py

@@ -0,0 +1,399 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Builds the CIFAR-10 network.
+
+Summary of available functions:
+
+ # Compute input images and labels for training. If you would like to run
+ # evaluations, use inputs() instead.
+ inputs, labels = distorted_inputs()
+
+ # Compute inference on the model inputs to make a prediction.
+ predictions = inference(inputs)
+
+ # Compute the total loss of the prediction with respect to the labels.
+ loss = loss(predictions, labels)
+
+ # Create a graph to run one step of training with respect to the loss.
+ train_op = train(loss, global_step)
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import re
+import sys
+import tarfile
+
+from six.moves import urllib
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10_input
+
+FLAGS = tf.app.flags.FLAGS
+
+# Basic model parameters.
+tf.app.flags.DEFINE_integer('batch_size', 128,
+                            """Number of images to process in a batch.""")
+tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
+                           """Path to the CIFAR-10 data directory.""")
+tf.app.flags.DEFINE_boolean('use_fp16', False,
+                            """Train the model using fp16.""")
+
+# Global constants describing the CIFAR-10 data set.
+IMAGE_SIZE = cifar10_input.IMAGE_SIZE
+NUM_CLASSES = cifar10_input.NUM_CLASSES
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
+
+# Constants describing the training process.
+MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
+NUM_EPOCHS_PER_DECAY = 350.0      # Epochs after which learning rate decays.
+LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
+INITIAL_LEARNING_RATE = 0.1       # Initial learning rate.
+
+# If a model is trained with multiple GPUs, prefix all Op names with tower_name
+# to differentiate the operations. Note that this prefix is removed from the
+# names of the summaries when visualizing a model.
+TOWER_NAME = 'tower'
+
+DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
+
+
+def _activation_summary(x):
+  """Helper to create summaries for activations.
+
+  Creates a summary that provides a histogram of activations.
+  Creates a summary that measures the sparsity of activations.
+
+  Args:
+    x: Tensor
+  Returns:
+    nothing
+  """
+  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+  # session. This helps the clarity of presentation on tensorboard.
+  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
+  tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
+  tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
+                                       tf.nn.zero_fraction(x))
+
+
+def _variable_on_cpu(name, shape, initializer):
+  """Helper to create a Variable stored on CPU memory.
+
+  Args:
+    name: name of the variable
+    shape: list of ints
+    initializer: initializer for Variable
+
+  Returns:
+    Variable Tensor
+  """
+  with tf.device('/cpu:0'):
+    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
+  return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+  """Helper to create an initialized Variable with weight decay.
+
+  Note that the Variable is initialized with a truncated normal distribution.
+  A weight decay is added only if one is specified.
+
+  Args:
+    name: name of the variable
+    shape: list of ints
+    stddev: standard deviation of a truncated Gaussian
+    wd: add L2Loss weight decay multiplied by this float. If None, weight
+        decay is not added for this Variable.
+
+  Returns:
+    Variable Tensor
+  """
+  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+  var = _variable_on_cpu(
+      name,
+      shape,
+      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
+  if wd is not None:
+    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
+    tf.add_to_collection('losses', weight_decay)
+  return var
+
+
+def distorted_inputs():
+  """Construct distorted input for CIFAR training using the Reader ops.
+
+  Returns:
+    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+    labels: Labels. 1D tensor of [batch_size] size.
+
+  Raises:
+    ValueError: If no data_dir
+  """
+  if not FLAGS.data_dir:
+    raise ValueError('Please supply a data_dir')
+  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
+                                                  batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
+
+
+def inputs(eval_data):
+  """Construct input for CIFAR evaluation using the Reader ops.
+
+  Args:
+    eval_data: bool, indicating if one should use the train or eval data set.
+
+  Returns:
+    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+    labels: Labels. 1D tensor of [batch_size] size.
+
+  Raises:
+    ValueError: If no data_dir
+  """
+  if not FLAGS.data_dir:
+    raise ValueError('Please supply a data_dir')
+  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+  images, labels = cifar10_input.inputs(eval_data=eval_data,
+                                        data_dir=data_dir,
+                                        batch_size=FLAGS.batch_size)
+  if FLAGS.use_fp16:
+    images = tf.cast(images, tf.float16)
+    labels = tf.cast(labels, tf.float16)
+  return images, labels
+
+
+def inference(images):
+  """Build the CIFAR-10 model.
+
+  Args:
+    images: Images returned from distorted_inputs() or inputs().
+
+  Returns:
+    Logits.
+  """
+  # We instantiate all variables using tf.get_variable() instead of
+  # tf.Variable() in order to share variables across multiple GPU training runs.
+  # If we only ran this model on a single GPU, we could simplify this function
+  # by replacing all instances of tf.get_variable() with tf.Variable().
+  #
+  # conv1
+  with tf.variable_scope('conv1') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 3, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
+    pre_activation = tf.nn.bias_add(conv, biases)
+    conv1 = tf.nn.relu(pre_activation, name=scope.name)
+    _activation_summary(conv1)
+
+  # pool1
+  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
+                         padding='SAME', name='pool1')
+  # norm1
+  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+                    name='norm1')
+
+  # conv2
+  with tf.variable_scope('conv2') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 64, 64],
+                                         stddev=5e-2,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
+    pre_activation = tf.nn.bias_add(conv, biases)
+    conv2 = tf.nn.relu(pre_activation, name=scope.name)
+    _activation_summary(conv2)
+
+  # norm2
+  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
+                    name='norm2')
+  # pool2
+  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
+                         strides=[1, 2, 2, 1], padding='SAME', name='pool2')
+
+  # local3
+  with tf.variable_scope('local3') as scope:
+    # Move everything into depth so we can perform a single matrix multiply.
+    reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
+    dim = reshape.get_shape()[1].value
+    weights = _variable_with_weight_decay('weights', shape=[dim, 384],
+                                          stddev=0.04, wd=0.004)
+    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
+    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
+    _activation_summary(local3)
+
+  # local4
+  with tf.variable_scope('local4') as scope:
+    weights = _variable_with_weight_decay('weights', shape=[384, 192],
+                                          stddev=0.04, wd=0.004)
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
+    _activation_summary(local4)
+
+  # linear layer(WX + b),
+  # We don't apply softmax here because
+  # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
+  # and performs the softmax internally for efficiency.
+  with tf.variable_scope('softmax_linear') as scope:
+    weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
+                                          stddev=1/192.0, wd=0.0)
+    biases = _variable_on_cpu('biases', [NUM_CLASSES],
+                              tf.constant_initializer(0.0))
+    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
+    _activation_summary(softmax_linear)
+
+  return softmax_linear
+
+
+def loss(logits, labels):
+  """Add L2Loss to all the trainable variables.
+
+  Add summary for "Loss" and "Loss/avg".
+  Args:
+    logits: Logits from inference().
+    labels: Labels from distorted_inputs or inputs(). 1-D tensor
+            of shape [batch_size]
+
+  Returns:
+    Loss tensor of type float.
+  """
+  # Calculate the average cross entropy loss across the batch.
+  labels = tf.cast(labels, tf.int64)
+  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      logits, labels, name='cross_entropy_per_example')
+  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
+  tf.add_to_collection('losses', cross_entropy_mean)
+
+  # The total loss is defined as the cross entropy loss plus all of the weight
+  # decay terms (L2 loss).
+  return tf.add_n(tf.get_collection('losses'), name='total_loss')
+
+
+def _add_loss_summaries(total_loss):
+  """Add summaries for losses in CIFAR-10 model.
+
+  Generates moving average for all losses and associated summaries for
+  visualizing the performance of the network.
+
+  Args:
+    total_loss: Total loss from loss().
+  Returns:
+    loss_averages_op: op for generating moving averages of losses.
+  """
+  # Compute the moving average of all individual losses and the total loss.
+  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+  losses = tf.get_collection('losses')
+  loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+  # Attach a scalar summary to all individual losses and the total loss; do the
+  # same for the averaged version of the losses.
+  for l in losses + [total_loss]:
+    # Name each loss as '(raw)' and name the moving average version of the loss
+    # as the original loss name.
+    tf.contrib.deprecated.scalar_summary(l.op.name + ' (raw)', l)
+    tf.contrib.deprecated.scalar_summary(l.op.name, loss_averages.average(l))
+
+  return loss_averages_op
+
+
+def train(total_loss, global_step):
+  """Train CIFAR-10 model.
+
+  Create an optimizer and apply to all trainable variables. Add moving
+  average for all trainable variables.
+
+  Args:
+    total_loss: Total loss from loss().
+    global_step: Integer Variable counting the number of training steps
+      processed.
+  Returns:
+    train_op: op for training.
+  """
+  # Variables that affect learning rate.
+  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
+  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
+
+  # Decay the learning rate exponentially based on the number of steps.
+  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
+                                  global_step,
+                                  decay_steps,
+                                  LEARNING_RATE_DECAY_FACTOR,
+                                  staircase=True)
+  tf.contrib.deprecated.scalar_summary('learning_rate', lr)
+
+  # Generate moving averages of all losses and associated summaries.
+  loss_averages_op = _add_loss_summaries(total_loss)
+
+  # Compute gradients.
+  with tf.control_dependencies([loss_averages_op]):
+    opt = tf.train.GradientDescentOptimizer(lr)
+    grads = opt.compute_gradients(total_loss)
+
+  # Apply gradients.
+  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+  # Add histograms for trainable variables.
+  for var in tf.trainable_variables():
+    tf.contrib.deprecated.histogram_summary(var.op.name, var)
+
+  # Add histograms for gradients.
+  for grad, var in grads:
+    if grad is not None:
+      tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients', grad)
+
+  # Track the moving averages of all trainable variables.
+  variable_averages = tf.train.ExponentialMovingAverage(
+      MOVING_AVERAGE_DECAY, global_step)
+  variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
+    train_op = tf.no_op(name='train')
+
+  return train_op
+
+
+def maybe_download_and_extract():
+  """Download and extract the tarball from Alex's website."""
+  dest_directory = FLAGS.data_dir
+  if not os.path.exists(dest_directory):
+    os.makedirs(dest_directory)
+  filename = DATA_URL.split('/')[-1]
+  filepath = os.path.join(dest_directory, filename)
+  if not os.path.exists(filepath):
+    def _progress(count, block_size, total_size):
+      sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
+          float(count * block_size) / float(total_size) * 100.0))
+      sys.stdout.flush()
+    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+    print()
+    statinfo = os.stat(filepath)
+    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+
+  tarfile.open(filepath, 'r:gz').extractall(dest_directory)

+ 157 - 0
tutorials/image/cifar10/cifar10_eval.py

@@ -0,0 +1,157 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Evaluation for CIFAR-10.
+
+Accuracy:
+cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
+of data) as judged by cifar10_eval.py.
+
+Speed:
+On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
+in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
+accuracy after 100K steps in 8 hours of training time.
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import math
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
+                           """Directory where to write event logs.""")
+tf.app.flags.DEFINE_string('eval_data', 'test',
+                           """Either 'test' or 'train_eval'.""")
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
+                           """Directory where to read model checkpoints.""")
+tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
+                            """How often to run the eval.""")
+tf.app.flags.DEFINE_integer('num_examples', 10000,
+                            """Number of examples to run.""")
+tf.app.flags.DEFINE_boolean('run_once', False,
+                         """Whether to run eval only once.""")
+
+
+def eval_once(saver, summary_writer, top_k_op, summary_op):
+  """Run Eval once.
+
+  Args:
+    saver: Saver.
+    summary_writer: Summary writer.
+    top_k_op: Top K op.
+    summary_op: Summary op.
+  """
+  with tf.Session() as sess:
+    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+    if ckpt and ckpt.model_checkpoint_path:
+      # Restores from checkpoint
+      saver.restore(sess, ckpt.model_checkpoint_path)
+      # Assuming model_checkpoint_path looks something like:
+      #   /my-favorite-path/cifar10_train/model.ckpt-0,
+      # extract global_step from it.
+      global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
+    else:
+      print('No checkpoint file found')
+      return
+
+    # Start the queue runners.
+    coord = tf.train.Coordinator()
+    try:
+      threads = []
+      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
+                                         start=True))
+
+      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
+      true_count = 0  # Counts the number of correct predictions.
+      total_sample_count = num_iter * FLAGS.batch_size
+      step = 0
+      while step < num_iter and not coord.should_stop():
+        predictions = sess.run([top_k_op])
+        true_count += np.sum(predictions)
+        step += 1
+
+      # Compute precision @ 1.
+      precision = true_count / total_sample_count
+      print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
+
+      summary = tf.Summary()
+      summary.ParseFromString(sess.run(summary_op))
+      summary.value.add(tag='Precision @ 1', simple_value=precision)
+      summary_writer.add_summary(summary, global_step)
+    except Exception as e:  # pylint: disable=broad-except
+      coord.request_stop(e)
+
+    coord.request_stop()
+    coord.join(threads, stop_grace_period_secs=10)
+
+
+def evaluate():
+  """Eval CIFAR-10 for a number of steps."""
+  with tf.Graph().as_default() as g:
+    # Get images and labels for CIFAR-10.
+    eval_data = FLAGS.eval_data == 'test'
+    images, labels = cifar10.inputs(eval_data=eval_data)
+
+    # Build a Graph that computes the logits predictions from the
+    # inference model.
+    logits = cifar10.inference(images)
+
+    # Calculate predictions.
+    top_k_op = tf.nn.in_top_k(logits, labels, 1)
+
+    # Restore the moving average version of the learned variables for eval.
+    variable_averages = tf.train.ExponentialMovingAverage(
+        cifar10.MOVING_AVERAGE_DECAY)
+    variables_to_restore = variable_averages.variables_to_restore()
+    saver = tf.train.Saver(variables_to_restore)
+
+    # Build the summary operation based on the TF collection of Summaries.
+    summary_op = tf.summary.merge_all()
+
+    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
+
+    while True:
+      eval_once(saver, summary_writer, top_k_op, summary_op)
+      if FLAGS.run_once:
+        break
+      time.sleep(FLAGS.eval_interval_secs)
+
+
+def main(argv=None):  # pylint: disable=unused-argument
+  cifar10.maybe_download_and_extract()
+  if tf.gfile.Exists(FLAGS.eval_dir):
+    tf.gfile.DeleteRecursively(FLAGS.eval_dir)
+  tf.gfile.MakeDirs(FLAGS.eval_dir)
+  evaluate()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 253 - 0
tutorials/image/cifar10/cifar10_input.py

@@ -0,0 +1,253 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Routine for decoding the CIFAR-10 binary file format."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+# Process images of this size. Note that this differs from the original CIFAR
+# image size of 32 x 32. If one alters this number, then the entire model
+# architecture will change and any model would need to be retrained.
+IMAGE_SIZE = 24
+
+# Global constants describing the CIFAR-10 data set.
+NUM_CLASSES = 10
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+
+
+def read_cifar10(filename_queue):
+  """Reads and parses examples from CIFAR10 data files.
+
+  Recommendation: if you want N-way read parallelism, call this function
+  N times.  This will give you N independent Readers reading different
+  files & positions within those files, which will give better mixing of
+  examples.
+
+  Args:
+    filename_queue: A queue of strings with the filenames to read from.
+
+  Returns:
+    An object representing a single example, with the following fields:
+      height: number of rows in the result (32)
+      width: number of columns in the result (32)
+      depth: number of color channels in the result (3)
+      key: a scalar string Tensor describing the filename & record number
+        for this example.
+      label: an int32 Tensor with the label in the range 0..9.
+      uint8image: a [height, width, depth] uint8 Tensor with the image data
+  """
+
+  class CIFAR10Record(object):
+    pass
+  result = CIFAR10Record()
+
+  # Dimensions of the images in the CIFAR-10 dataset.
+  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+  # input format.
+  label_bytes = 1  # 2 for CIFAR-100
+  result.height = 32
+  result.width = 32
+  result.depth = 3
+  image_bytes = result.height * result.width * result.depth
+  # Every record consists of a label followed by the image, with a
+  # fixed number of bytes for each.
+  record_bytes = label_bytes + image_bytes
+
+  # Read a record, getting filenames from the filename_queue.  No
+  # header or footer in the CIFAR-10 format, so we leave header_bytes
+  # and footer_bytes at their default of 0.
+  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
+  result.key, value = reader.read(filename_queue)
+
+  # Convert from a string to a vector of uint8 that is record_bytes long.
+  record_bytes = tf.decode_raw(value, tf.uint8)
+
+  # The first bytes represent the label, which we convert from uint8->int32.
+  result.label = tf.cast(
+      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+
+  # The remaining bytes after the label represent the image, which we reshape
+  # from [depth * height * width] to [depth, height, width].
+  depth_major = tf.reshape(
+      tf.strided_slice(record_bytes, [label_bytes],
+                       [label_bytes + image_bytes]),
+      [result.depth, result.height, result.width])
+  # Convert from [depth, height, width] to [height, width, depth].
+  result.uint8image = tf.transpose(depth_major, [1, 2, 0])
+
+  return result
+
+
+def _generate_image_and_label_batch(image, label, min_queue_examples,
+                                    batch_size, shuffle):
+  """Construct a queued batch of images and labels.
+
+  Args:
+    image: 3-D Tensor of [height, width, 3] of type.float32.
+    label: 1-D Tensor of type.int32
+    min_queue_examples: int32, minimum number of samples to retain
+      in the queue that provides of batches of examples.
+    batch_size: Number of images per batch.
+    shuffle: boolean indicating whether to use a shuffling queue.
+
+  Returns:
+    images: Images. 4D tensor of [batch_size, height, width, 3] size.
+    labels: Labels. 1D tensor of [batch_size] size.
+  """
+  # Create a queue that shuffles the examples, and then
+  # read 'batch_size' images + labels from the example queue.
+  num_preprocess_threads = 16
+  if shuffle:
+    images, label_batch = tf.train.shuffle_batch(
+        [image, label],
+        batch_size=batch_size,
+        num_threads=num_preprocess_threads,
+        capacity=min_queue_examples + 3 * batch_size,
+        min_after_dequeue=min_queue_examples)
+  else:
+    images, label_batch = tf.train.batch(
+        [image, label],
+        batch_size=batch_size,
+        num_threads=num_preprocess_threads,
+        capacity=min_queue_examples + 3 * batch_size)
+
+  # Display the training images in the visualizer.
+  tf.contrib.deprecated.image_summary('images', images)
+
+  return images, tf.reshape(label_batch, [batch_size])
+
+
+def distorted_inputs(data_dir, batch_size):
+  """Construct distorted input for CIFAR training using the Reader ops.
+
+  Args:
+    data_dir: Path to the CIFAR-10 data directory.
+    batch_size: Number of images per batch.
+
+  Returns:
+    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+    labels: Labels. 1D tensor of [batch_size] size.
+  """
+  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+               for i in xrange(1, 6)]
+  for f in filenames:
+    if not tf.gfile.Exists(f):
+      raise ValueError('Failed to find file: ' + f)
+
+  # Create a queue that produces the filenames to read.
+  filename_queue = tf.train.string_input_producer(filenames)
+
+  # Read examples from files in the filename queue.
+  read_input = read_cifar10(filename_queue)
+  reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+  height = IMAGE_SIZE
+  width = IMAGE_SIZE
+
+  # Image processing for training the network. Note the many random
+  # distortions applied to the image.
+
+  # Randomly crop a [height, width] section of the image.
+  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
+
+  # Randomly flip the image horizontally.
+  distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+  # Because these operations are not commutative, consider randomizing
+  # the order their operation.
+  distorted_image = tf.image.random_brightness(distorted_image,
+                                               max_delta=63)
+  distorted_image = tf.image.random_contrast(distorted_image,
+                                             lower=0.2, upper=1.8)
+
+  # Subtract off the mean and divide by the variance of the pixels.
+  float_image = tf.image.per_image_standardization(distorted_image)
+
+  # Set the shapes of tensors.
+  float_image.set_shape([height, width, 3])
+  read_input.label.set_shape([1])
+
+  # Ensure that the random shuffling has good mixing properties.
+  min_fraction_of_examples_in_queue = 0.4
+  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
+                           min_fraction_of_examples_in_queue)
+  print ('Filling queue with %d CIFAR images before starting to train. '
+         'This will take a few minutes.' % min_queue_examples)
+
+  # Generate a batch of images and labels by building up a queue of examples.
+  return _generate_image_and_label_batch(float_image, read_input.label,
+                                         min_queue_examples, batch_size,
+                                         shuffle=True)
+
+
+def inputs(eval_data, data_dir, batch_size):
+  """Construct input for CIFAR evaluation using the Reader ops.
+
+  Args:
+    eval_data: bool, indicating if one should use the train or eval data set.
+    data_dir: Path to the CIFAR-10 data directory.
+    batch_size: Number of images per batch.
+
+  Returns:
+    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+    labels: Labels. 1D tensor of [batch_size] size.
+  """
+  if not eval_data:
+    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+                 for i in xrange(1, 6)]
+    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+  else:
+    filenames = [os.path.join(data_dir, 'test_batch.bin')]
+    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
+  for f in filenames:
+    if not tf.gfile.Exists(f):
+      raise ValueError('Failed to find file: ' + f)
+
+  # Create a queue that produces the filenames to read.
+  filename_queue = tf.train.string_input_producer(filenames)
+
+  # Read examples from files in the filename queue.
+  read_input = read_cifar10(filename_queue)
+  reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+  height = IMAGE_SIZE
+  width = IMAGE_SIZE
+
+  # Image processing for evaluation.
+  # Crop the central [height, width] of the image.
+  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
+                                                         width, height)
+
+  # Subtract off the mean and divide by the variance of the pixels.
+  float_image = tf.image.per_image_standardization(resized_image)
+
+  # Ensure that the random shuffling has good mixing properties.
+  min_fraction_of_examples_in_queue = 0.4
+  min_queue_examples = int(num_examples_per_epoch *
+                           min_fraction_of_examples_in_queue)
+
+  # Generate a batch of images and labels by building up a queue of examples.
+  return _generate_image_and_label_batch(float_image, read_input.label,
+                                         min_queue_examples, batch_size,
+                                         shuffle=False)

+ 66 - 0
tutorials/image/cifar10/cifar10_input_test.py

@@ -0,0 +1,66 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for cifar10 input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10_input
+
+
+class CIFAR10InputTest(tf.test.TestCase):
+
+  def _record(self, label, red, green, blue):
+    image_size = 32 * 32
+    record = bytes(bytearray([label] + [red] * image_size +
+                             [green] * image_size + [blue] * image_size))
+    expected = [[[red, green, blue]] * 32] * 32
+    return record, expected
+
+  def testSimple(self):
+    labels = [9, 3, 0]
+    records = [self._record(labels[0], 0, 128, 255),
+               self._record(labels[1], 255, 0, 1),
+               self._record(labels[2], 254, 255, 0)]
+    contents = b"".join([record for record, _ in records])
+    expected = [expected for _, expected in records]
+    filename = os.path.join(self.get_temp_dir(), "cifar")
+    open(filename, "wb").write(contents)
+
+    with self.test_session() as sess:
+      q = tf.FIFOQueue(99, [tf.string], shapes=())
+      q.enqueue([filename]).run()
+      q.close().run()
+      result = cifar10_input.read_cifar10(q)
+
+      for i in range(3):
+        key, label, uint8image = sess.run([
+            result.key, result.label, result.uint8image])
+        self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))
+        self.assertEqual(labels[i], label)
+        self.assertAllEqual(expected[i], uint8image)
+
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run([result.key, result.uint8image])
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 273 - 0
tutorials/image/cifar10/cifar10_multi_gpu_train.py

@@ -0,0 +1,273 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A binary to train CIFAR-10 using multiple GPU's with synchronous updates.
+
+Accuracy:
+cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256
+epochs of data) as judged by cifar10_eval.py.
+
+Speed: With batch_size 128.
+
+System        | Step Time (sec/batch)  |     Accuracy
+--------------------------------------------------------------------
+1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
+1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
+2 Tesla K20m  | 0.13-0.20              | ~84% at 30K steps  (2.5 hours)
+3 Tesla K20m  | 0.13-0.18              | ~84% at 30K steps
+4 Tesla K20m  | ~0.10                  | ~84% at 30K steps
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import os.path
+import re
+import time
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
+                           """Directory where to write event logs """
+                           """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000,
+                            """Number of batches to run.""")
+tf.app.flags.DEFINE_integer('num_gpus', 1,
+                            """How many GPUs to use.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+                            """Whether to log device placement.""")
+
+
+def tower_loss(scope):
+  """Calculate the total loss on a single tower running the CIFAR model.
+
+  Args:
+    scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'
+
+  Returns:
+     Tensor of shape [] containing the total loss for a batch of data
+  """
+  # Get images and labels for CIFAR-10.
+  images, labels = cifar10.distorted_inputs()
+
+  # Build inference Graph.
+  logits = cifar10.inference(images)
+
+  # Build the portion of the Graph calculating the losses. Note that we will
+  # assemble the total_loss using a custom function below.
+  _ = cifar10.loss(logits, labels)
+
+  # Assemble all of the losses for the current tower only.
+  losses = tf.get_collection('losses', scope)
+
+  # Calculate the total loss for the current tower.
+  total_loss = tf.add_n(losses, name='total_loss')
+
+  # Attach a scalar summary to all individual losses and the total loss; do the
+  # same for the averaged version of the losses.
+  for l in losses + [total_loss]:
+    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
+    # session. This helps the clarity of presentation on tensorboard.
+    loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
+    tf.contrib.deprecated.scalar_summary(loss_name, l)
+
+  return total_loss
+
+
+def average_gradients(tower_grads):
+  """Calculate the average gradient for each shared variable across all towers.
+
+  Note that this function provides a synchronization point across all towers.
+
+  Args:
+    tower_grads: List of lists of (gradient, variable) tuples. The outer list
+      is over individual gradients. The inner list is over the gradient
+      calculation for each tower.
+  Returns:
+     List of pairs of (gradient, variable) where the gradient has been averaged
+     across all towers.
+  """
+  average_grads = []
+  for grad_and_vars in zip(*tower_grads):
+    # Note that each grad_and_vars looks like the following:
+    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+    grads = []
+    for g, _ in grad_and_vars:
+      # Add 0 dimension to the gradients to represent the tower.
+      expanded_g = tf.expand_dims(g, 0)
+
+      # Append on a 'tower' dimension which we will average over below.
+      grads.append(expanded_g)
+
+    # Average over the 'tower' dimension.
+    grad = tf.concat_v2(grads, 0)
+    grad = tf.reduce_mean(grad, 0)
+
+    # Keep in mind that the Variables are redundant because they are shared
+    # across towers. So .. we will just return the first tower's pointer to
+    # the Variable.
+    v = grad_and_vars[0][1]
+    grad_and_var = (grad, v)
+    average_grads.append(grad_and_var)
+  return average_grads
+
+
+def train():
+  """Train CIFAR-10 for a number of steps."""
+  with tf.Graph().as_default(), tf.device('/cpu:0'):
+    # Create a variable to count the number of train() calls. This equals the
+    # number of batches processed * FLAGS.num_gpus.
+    global_step = tf.get_variable(
+        'global_step', [],
+        initializer=tf.constant_initializer(0), trainable=False)
+
+    # Calculate the learning rate schedule.
+    num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
+                             FLAGS.batch_size)
+    decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)
+
+    # Decay the learning rate exponentially based on the number of steps.
+    lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
+                                    global_step,
+                                    decay_steps,
+                                    cifar10.LEARNING_RATE_DECAY_FACTOR,
+                                    staircase=True)
+
+    # Create an optimizer that performs gradient descent.
+    opt = tf.train.GradientDescentOptimizer(lr)
+
+    # Calculate the gradients for each model tower.
+    tower_grads = []
+    for i in xrange(FLAGS.num_gpus):
+      with tf.device('/gpu:%d' % i):
+        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+          # Calculate the loss for one tower of the CIFAR model. This function
+          # constructs the entire CIFAR model but shares the variables across
+          # all towers.
+          loss = tower_loss(scope)
+
+          # Reuse variables for the next tower.
+          tf.get_variable_scope().reuse_variables()
+
+          # Retain the summaries from the final tower.
+          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+          # Calculate the gradients for the batch of data on this CIFAR tower.
+          grads = opt.compute_gradients(loss)
+
+          # Keep track of the gradients across all towers.
+          tower_grads.append(grads)
+
+    # We must calculate the mean of each gradient. Note that this is the
+    # synchronization point across all towers.
+    grads = average_gradients(tower_grads)
+
+    # Add a summary to track the learning rate.
+    summaries.append(tf.contrib.deprecated.scalar_summary('learning_rate', lr))
+
+    # Add histograms for gradients.
+    for grad, var in grads:
+      if grad is not None:
+        summaries.append(
+            tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients',
+                                                    grad))
+
+    # Apply the gradients to adjust the shared variables.
+    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+    # Add histograms for trainable variables.
+    for var in tf.trainable_variables():
+      summaries.append(
+          tf.contrib.deprecated.histogram_summary(var.op.name, var))
+
+    # Track the moving averages of all trainable variables.
+    variable_averages = tf.train.ExponentialMovingAverage(
+        cifar10.MOVING_AVERAGE_DECAY, global_step)
+    variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+    # Group all updates to into a single train op.
+    train_op = tf.group(apply_gradient_op, variables_averages_op)
+
+    # Create a saver.
+    saver = tf.train.Saver(tf.global_variables())
+
+    # Build the summary operation from the last tower summaries.
+    summary_op = tf.contrib.deprecated.merge_summary(summaries)
+
+    # Build an initialization operation to run below.
+    init = tf.global_variables_initializer()
+
+    # Start running operations on the Graph. allow_soft_placement must be set to
+    # True to build towers on GPU, as some of the ops do not have GPU
+    # implementations.
+    sess = tf.Session(config=tf.ConfigProto(
+        allow_soft_placement=True,
+        log_device_placement=FLAGS.log_device_placement))
+    sess.run(init)
+
+    # Start the queue runners.
+    tf.train.start_queue_runners(sess=sess)
+
+    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
+
+    for step in xrange(FLAGS.max_steps):
+      start_time = time.time()
+      _, loss_value = sess.run([train_op, loss])
+      duration = time.time() - start_time
+
+      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+      if step % 10 == 0:
+        num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
+        examples_per_sec = num_examples_per_step / duration
+        sec_per_batch = duration / FLAGS.num_gpus
+
+        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+                      'sec/batch)')
+        print (format_str % (datetime.now(), step, loss_value,
+                             examples_per_sec, sec_per_batch))
+
+      if step % 100 == 0:
+        summary_str = sess.run(summary_op)
+        summary_writer.add_summary(summary_str, step)
+
+      # Save the model checkpoint periodically.
+      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
+        saver.save(sess, checkpoint_path, global_step=step)
+
+
+def main(argv=None):  # pylint: disable=unused-argument
+  cifar10.maybe_download_and_extract()
+  if tf.gfile.Exists(FLAGS.train_dir):
+    tf.gfile.DeleteRecursively(FLAGS.train_dir)
+  tf.gfile.MakeDirs(FLAGS.train_dir)
+  train()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 120 - 0
tutorials/image/cifar10/cifar10_train.py

@@ -0,0 +1,120 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A binary to train CIFAR-10 using a single GPU.
+
+Accuracy:
+cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
+data) as judged by cifar10_eval.py.
+
+Speed: With batch_size 128.
+
+System        | Step Time (sec/batch)  |     Accuracy
+------------------------------------------------------------------
+1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
+1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
+
+Usage:
+Please see the tutorial and website for how to download the CIFAR-10
+data set, compile the program and train the model.
+
+http://tensorflow.org/tutorials/deep_cnn/
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import time
+
+import tensorflow as tf
+
+from tensorflow.models.image.cifar10 import cifar10
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
+                           """Directory where to write event logs """
+                           """and checkpoint.""")
+tf.app.flags.DEFINE_integer('max_steps', 1000000,
+                            """Number of batches to run.""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False,
+                            """Whether to log device placement.""")
+
+
+def train():
+  """Train CIFAR-10 for a number of steps."""
+  with tf.Graph().as_default():
+    global_step = tf.contrib.framework.get_or_create_global_step()
+
+    # Get images and labels for CIFAR-10.
+    images, labels = cifar10.distorted_inputs()
+
+    # Build a Graph that computes the logits predictions from the
+    # inference model.
+    logits = cifar10.inference(images)
+
+    # Calculate loss.
+    loss = cifar10.loss(logits, labels)
+
+    # Build a Graph that trains the model with one batch of examples and
+    # updates the model parameters.
+    train_op = cifar10.train(loss, global_step)
+
+    class _LoggerHook(tf.train.SessionRunHook):
+      """Logs loss and runtime."""
+
+      def begin(self):
+        self._step = -1
+
+      def before_run(self, run_context):
+        self._step += 1
+        self._start_time = time.time()
+        return tf.train.SessionRunArgs(loss)  # Asks for loss value.
+
+      def after_run(self, run_context, run_values):
+        duration = time.time() - self._start_time
+        loss_value = run_values.results
+        if self._step % 10 == 0:
+          num_examples_per_step = FLAGS.batch_size
+          examples_per_sec = num_examples_per_step / duration
+          sec_per_batch = float(duration)
+
+          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+                        'sec/batch)')
+          print (format_str % (datetime.now(), self._step, loss_value,
+                               examples_per_sec, sec_per_batch))
+
+    with tf.train.MonitoredTrainingSession(
+        checkpoint_dir=FLAGS.train_dir,
+        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
+               tf.train.NanTensorHook(loss),
+               _LoggerHook()],
+        config=tf.ConfigProto(
+            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
+      while not mon_sess.should_stop():
+        mon_sess.run(train_op)
+
+
+def main(argv=None):  # pylint: disable=unused-argument
+  cifar10.maybe_download_and_extract()
+  if tf.gfile.Exists(FLAGS.train_dir):
+    tf.gfile.DeleteRecursively(FLAGS.train_dir)
+  tf.gfile.MakeDirs(FLAGS.train_dir)
+  train()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 30 - 0
tutorials/image/imagenet/BUILD

@@ -0,0 +1,30 @@
+# Description:
+# Example TensorFlow models for ImageNet.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+    name = "classify_image",
+    srcs = [
+        "classify_image.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 227 - 0
tutorials/image/imagenet/classify_image.py

@@ -0,0 +1,227 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Simple image classification with Inception.
+
+Run image classification with Inception trained on ImageNet 2012 Challenge data
+set.
+
+This program creates a graph from a saved GraphDef protocol buffer,
+and runs inference on an input JPEG image. It outputs human readable
+strings of the top 5 predictions along with their probabilities.
+
+Change the --image_file argument to any jpg image to compute a
+classification of that image.
+
+Please see the tutorial and website for a detailed description of how
+to use this script to perform image recognition.
+
+https://tensorflow.org/tutorials/image_recognition/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os.path
+import re
+import sys
+import tarfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+FLAGS = None
+
+# pylint: disable=line-too-long
+DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
+# pylint: enable=line-too-long
+
+
+class NodeLookup(object):
+  """Converts integer node ID's to human readable labels."""
+
+  def __init__(self,
+               label_lookup_path=None,
+               uid_lookup_path=None):
+    if not label_lookup_path:
+      label_lookup_path = os.path.join(
+          FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
+    if not uid_lookup_path:
+      uid_lookup_path = os.path.join(
+          FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
+    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
+
+  def load(self, label_lookup_path, uid_lookup_path):
+    """Loads a human readable English name for each softmax node.
+
+    Args:
+      label_lookup_path: string UID to integer node ID.
+      uid_lookup_path: string UID to human-readable string.
+
+    Returns:
+      dict from integer node ID to human-readable string.
+    """
+    if not tf.gfile.Exists(uid_lookup_path):
+      tf.logging.fatal('File does not exist %s', uid_lookup_path)
+    if not tf.gfile.Exists(label_lookup_path):
+      tf.logging.fatal('File does not exist %s', label_lookup_path)
+
+    # Loads mapping from string UID to human-readable string
+    proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
+    uid_to_human = {}
+    p = re.compile(r'[n\d]*[ \S,]*')
+    for line in proto_as_ascii_lines:
+      parsed_items = p.findall(line)
+      uid = parsed_items[0]
+      human_string = parsed_items[2]
+      uid_to_human[uid] = human_string
+
+    # Loads mapping from string UID to integer node ID.
+    node_id_to_uid = {}
+    proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
+    for line in proto_as_ascii:
+      if line.startswith('  target_class:'):
+        target_class = int(line.split(': ')[1])
+      if line.startswith('  target_class_string:'):
+        target_class_string = line.split(': ')[1]
+        node_id_to_uid[target_class] = target_class_string[1:-2]
+
+    # Loads the final mapping of integer node ID to human-readable string
+    node_id_to_name = {}
+    for key, val in node_id_to_uid.items():
+      if val not in uid_to_human:
+        tf.logging.fatal('Failed to locate: %s', val)
+      name = uid_to_human[val]
+      node_id_to_name[key] = name
+
+    return node_id_to_name
+
+  def id_to_string(self, node_id):
+    if node_id not in self.node_lookup:
+      return ''
+    return self.node_lookup[node_id]
+
+
+def create_graph():
+  """Creates a graph from saved GraphDef file and returns a saver."""
+  # Creates graph from saved graph_def.pb.
+  with tf.gfile.FastGFile(os.path.join(
+      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+    _ = tf.import_graph_def(graph_def, name='')
+
+
+def run_inference_on_image(image):
+  """Runs inference on an image.
+
+  Args:
+    image: Image file name.
+
+  Returns:
+    Nothing
+  """
+  if not tf.gfile.Exists(image):
+    tf.logging.fatal('File does not exist %s', image)
+  image_data = tf.gfile.FastGFile(image, 'rb').read()
+
+  # Creates graph from saved GraphDef.
+  create_graph()
+
+  with tf.Session() as sess:
+    # Some useful tensors:
+    # 'softmax:0': A tensor containing the normalized prediction across
+    #   1000 labels.
+    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
+    #   float description of the image.
+    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
+    #   encoding of the image.
+    # Runs the softmax tensor by feeding the image_data as input to the graph.
+    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
+    predictions = sess.run(softmax_tensor,
+                           {'DecodeJpeg/contents:0': image_data})
+    predictions = np.squeeze(predictions)
+
+    # Creates node ID --> English string lookup.
+    node_lookup = NodeLookup()
+
+    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
+    for node_id in top_k:
+      human_string = node_lookup.id_to_string(node_id)
+      score = predictions[node_id]
+      print('%s (score = %.5f)' % (human_string, score))
+
+
+def maybe_download_and_extract():
+  """Download and extract model tar file."""
+  dest_directory = FLAGS.model_dir
+  if not os.path.exists(dest_directory):
+    os.makedirs(dest_directory)
+  filename = DATA_URL.split('/')[-1]
+  filepath = os.path.join(dest_directory, filename)
+  if not os.path.exists(filepath):
+    def _progress(count, block_size, total_size):
+      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+          filename, float(count * block_size) / float(total_size) * 100.0))
+      sys.stdout.flush()
+    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
+    print()
+    statinfo = os.stat(filepath)
+    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+  tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+
+
+def main(_):
+  maybe_download_and_extract()
+  image = (FLAGS.image_file if FLAGS.image_file else
+           os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
+  run_inference_on_image(image)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  # classify_image_graph_def.pb:
+  #   Binary representation of the GraphDef protocol buffer.
+  # imagenet_synset_to_human_label_map.txt:
+  #   Map from synset ID to a human readable string.
+  # imagenet_2012_challenge_label_map_proto.pbtxt:
+  #   Text representation of a protocol buffer mapping a label to synset ID.
+  parser.add_argument(
+      '--model_dir',
+      type=str,
+      default='/tmp/imagenet',
+      help="""\
+      Path to classify_image_graph_def.pb,
+      imagenet_synset_to_human_label_map.txt, and
+      imagenet_2012_challenge_label_map_proto.pbtxt.\
+      """
+  )
+  parser.add_argument(
+      '--image_file',
+      type=str,
+      default='',
+      help='Absolute path to image file.'
+  )
+  parser.add_argument(
+      '--num_top_predictions',
+      type=int,
+      default=5,
+      help='Display this many predictions.'
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

+ 42 - 0
tutorials/image/mnist/BUILD

@@ -0,0 +1,42 @@
+# Description:
+# Example TensorFlow models for MNIST that achieves high accuracy
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+    name = "convolutional",
+    srcs = [
+        "convolutional.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_test(
+    name = "convolutional_test",
+    size = "medium",
+    srcs = [
+        "convolutional.py",
+    ],
+    args = [
+        "--self_test",
+    ],
+    main = "convolutional.py",
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 0 - 0
tutorials/image/mnist/__init__.py


+ 339 - 0
tutorials/image/mnist/convolutional.py

@@ -0,0 +1,339 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Simple, end-to-end, LeNet-5-like convolutional MNIST model example.
+
+This should achieve a test error of 0.7%. Please keep this model as simple and
+linear as possible, it is meant as a tutorial for simple convolutional models.
+Run with --self_test on the command line to execute a short self-test.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import gzip
+import os
+import sys
+import time
+
+import numpy
+from six.moves import urllib
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
+WORK_DIRECTORY = 'data'
+IMAGE_SIZE = 28
+NUM_CHANNELS = 1
+PIXEL_DEPTH = 255
+NUM_LABELS = 10
+VALIDATION_SIZE = 5000  # Size of the validation set.
+SEED = 66478  # Set to None for random seed.
+BATCH_SIZE = 64
+NUM_EPOCHS = 10
+EVAL_BATCH_SIZE = 64
+EVAL_FREQUENCY = 100  # Number of steps between evaluations.
+
+
+FLAGS = None
+
+
+def data_type():
+  """Return the type of the activations, weights, and placeholder variables."""
+  if FLAGS.use_fp16:
+    return tf.float16
+  else:
+    return tf.float32
+
+
+def maybe_download(filename):
+  """Download the data from Yann's website, unless it's already here."""
+  if not tf.gfile.Exists(WORK_DIRECTORY):
+    tf.gfile.MakeDirs(WORK_DIRECTORY)
+  filepath = os.path.join(WORK_DIRECTORY, filename)
+  if not tf.gfile.Exists(filepath):
+    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
+    with tf.gfile.GFile(filepath) as f:
+      size = f.size()
+    print('Successfully downloaded', filename, size, 'bytes.')
+  return filepath
+
+
+def extract_data(filename, num_images):
+  """Extract the images into a 4D tensor [image index, y, x, channels].
+
+  Values are rescaled from [0, 255] down to [-0.5, 0.5].
+  """
+  print('Extracting', filename)
+  with gzip.open(filename) as bytestream:
+    bytestream.read(16)
+    buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
+    data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
+    data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
+    data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
+    return data
+
+
+def extract_labels(filename, num_images):
+  """Extract the labels into a vector of int64 label IDs."""
+  print('Extracting', filename)
+  with gzip.open(filename) as bytestream:
+    bytestream.read(8)
+    buf = bytestream.read(1 * num_images)
+    labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
+  return labels
+
+
+def fake_data(num_images):
+  """Generate a fake dataset that matches the dimensions of MNIST."""
+  data = numpy.ndarray(
+      shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
+      dtype=numpy.float32)
+  labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64)
+  for image in xrange(num_images):
+    label = image % 2
+    data[image, :, :, 0] = label - 0.5
+    labels[image] = label
+  return data, labels
+
+
+def error_rate(predictions, labels):
+  """Return the error rate based on dense predictions and sparse labels."""
+  return 100.0 - (
+      100.0 *
+      numpy.sum(numpy.argmax(predictions, 1) == labels) /
+      predictions.shape[0])
+
+
+def main(_):
+  if FLAGS.self_test:
+    print('Running self-test.')
+    train_data, train_labels = fake_data(256)
+    validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
+    test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
+    num_epochs = 1
+  else:
+    # Get the data.
+    train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
+    train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
+    test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
+    test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
+
+    # Extract it into numpy arrays.
+    train_data = extract_data(train_data_filename, 60000)
+    train_labels = extract_labels(train_labels_filename, 60000)
+    test_data = extract_data(test_data_filename, 10000)
+    test_labels = extract_labels(test_labels_filename, 10000)
+
+    # Generate a validation set.
+    validation_data = train_data[:VALIDATION_SIZE, ...]
+    validation_labels = train_labels[:VALIDATION_SIZE]
+    train_data = train_data[VALIDATION_SIZE:, ...]
+    train_labels = train_labels[VALIDATION_SIZE:]
+    num_epochs = NUM_EPOCHS
+  train_size = train_labels.shape[0]
+
+  # This is where training samples and labels are fed to the graph.
+  # These placeholder nodes will be fed a batch of training data at each
+  # training step using the {feed_dict} argument to the Run() call below.
+  train_data_node = tf.placeholder(
+      data_type(),
+      shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
+  train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
+  eval_data = tf.placeholder(
+      data_type(),
+      shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
+
+  # The variables below hold all the trainable weights. They are passed an
+  # initial value which will be assigned when we call:
+  # {tf.global_variables_initializer().run()}
+  conv1_weights = tf.Variable(
+      tf.truncated_normal([5, 5, NUM_CHANNELS, 32],  # 5x5 filter, depth 32.
+                          stddev=0.1,
+                          seed=SEED, dtype=data_type()))
+  conv1_biases = tf.Variable(tf.zeros([32], dtype=data_type()))
+  conv2_weights = tf.Variable(tf.truncated_normal(
+      [5, 5, 32, 64], stddev=0.1,
+      seed=SEED, dtype=data_type()))
+  conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=data_type()))
+  fc1_weights = tf.Variable(  # fully connected, depth 512.
+      tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
+                          stddev=0.1,
+                          seed=SEED,
+                          dtype=data_type()))
+  fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=data_type()))
+  fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS],
+                                                stddev=0.1,
+                                                seed=SEED,
+                                                dtype=data_type()))
+  fc2_biases = tf.Variable(tf.constant(
+      0.1, shape=[NUM_LABELS], dtype=data_type()))
+
+  # We will replicate the model structure for the training subgraph, as well
+  # as the evaluation subgraphs, while sharing the trainable parameters.
+  def model(data, train=False):
+    """The Model definition."""
+    # 2D convolution, with 'SAME' padding (i.e. the output feature map has
+    # the same size as the input). Note that {strides} is a 4D array whose
+    # shape matches the data layout: [image index, y, x, depth].
+    conv = tf.nn.conv2d(data,
+                        conv1_weights,
+                        strides=[1, 1, 1, 1],
+                        padding='SAME')
+    # Bias and rectified linear non-linearity.
+    relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
+    # Max pooling. The kernel size spec {ksize} also follows the layout of
+    # the data. Here we have a pooling window of 2, and a stride of 2.
+    pool = tf.nn.max_pool(relu,
+                          ksize=[1, 2, 2, 1],
+                          strides=[1, 2, 2, 1],
+                          padding='SAME')
+    conv = tf.nn.conv2d(pool,
+                        conv2_weights,
+                        strides=[1, 1, 1, 1],
+                        padding='SAME')
+    relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
+    pool = tf.nn.max_pool(relu,
+                          ksize=[1, 2, 2, 1],
+                          strides=[1, 2, 2, 1],
+                          padding='SAME')
+    # Reshape the feature map cuboid into a 2D matrix to feed it to the
+    # fully connected layers.
+    pool_shape = pool.get_shape().as_list()
+    reshape = tf.reshape(
+        pool,
+        [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
+    # Fully connected layer. Note that the '+' operation automatically
+    # broadcasts the biases.
+    hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
+    # Add a 50% dropout during training only. Dropout also scales
+    # activations such that no rescaling is needed at evaluation time.
+    if train:
+      hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
+    return tf.matmul(hidden, fc2_weights) + fc2_biases
+
+  # Training computation: logits + cross-entropy loss.
+  logits = model(train_data_node, True)
+  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+      logits, train_labels_node))
+
+  # L2 regularization for the fully connected parameters.
+  regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
+                  tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
+  # Add the regularization term to the loss.
+  loss += 5e-4 * regularizers
+
+  # Optimizer: set up a variable that's incremented once per batch and
+  # controls the learning rate decay.
+  batch = tf.Variable(0, dtype=data_type())
+  # Decay once per epoch, using an exponential schedule starting at 0.01.
+  learning_rate = tf.train.exponential_decay(
+      0.01,                # Base learning rate.
+      batch * BATCH_SIZE,  # Current index into the dataset.
+      train_size,          # Decay step.
+      0.95,                # Decay rate.
+      staircase=True)
+  # Use simple momentum for the optimization.
+  optimizer = tf.train.MomentumOptimizer(learning_rate,
+                                         0.9).minimize(loss,
+                                                       global_step=batch)
+
+  # Predictions for the current training minibatch.
+  train_prediction = tf.nn.softmax(logits)
+
+  # Predictions for the test and validation, which we'll compute less often.
+  eval_prediction = tf.nn.softmax(model(eval_data))
+
+  # Small utility function to evaluate a dataset by feeding batches of data to
+  # {eval_data} and pulling the results from {eval_predictions}.
+  # Saves memory and enables this to run on smaller GPUs.
+  def eval_in_batches(data, sess):
+    """Get all predictions for a dataset by running it in small batches."""
+    size = data.shape[0]
+    if size < EVAL_BATCH_SIZE:
+      raise ValueError("batch size for evals larger than dataset: %d" % size)
+    predictions = numpy.ndarray(shape=(size, NUM_LABELS), dtype=numpy.float32)
+    for begin in xrange(0, size, EVAL_BATCH_SIZE):
+      end = begin + EVAL_BATCH_SIZE
+      if end <= size:
+        predictions[begin:end, :] = sess.run(
+            eval_prediction,
+            feed_dict={eval_data: data[begin:end, ...]})
+      else:
+        batch_predictions = sess.run(
+            eval_prediction,
+            feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, ...]})
+        predictions[begin:, :] = batch_predictions[begin - size:, :]
+    return predictions
+
+  # Create a local session to run the training.
+  start_time = time.time()
+  with tf.Session() as sess:
+    # Run all the initializers to prepare the trainable parameters.
+    tf.global_variables_initializer().run()
+    print('Initialized!')
+    # Loop through training steps.
+    for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
+      # Compute the offset of the current minibatch in the data.
+      # Note that we could use better randomization across epochs.
+      offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
+      batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
+      batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
+      # This dictionary maps the batch data (as a numpy array) to the
+      # node in the graph it should be fed to.
+      feed_dict = {train_data_node: batch_data,
+                   train_labels_node: batch_labels}
+      # Run the optimizer to update weights.
+      sess.run(optimizer, feed_dict=feed_dict)
+      # print some extra information once reach the evaluation frequency
+      if step % EVAL_FREQUENCY == 0:
+        # fetch some extra nodes' data
+        l, lr, predictions = sess.run([loss, learning_rate, train_prediction],
+                                      feed_dict=feed_dict)
+        elapsed_time = time.time() - start_time
+        start_time = time.time()
+        print('Step %d (epoch %.2f), %.1f ms' %
+              (step, float(step) * BATCH_SIZE / train_size,
+               1000 * elapsed_time / EVAL_FREQUENCY))
+        print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
+        print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels))
+        print('Validation error: %.1f%%' % error_rate(
+            eval_in_batches(validation_data, sess), validation_labels))
+        sys.stdout.flush()
+    # Finally print the result!
+    test_error = error_rate(eval_in_batches(test_data, sess), test_labels)
+    print('Test error: %.1f%%' % test_error)
+    if FLAGS.self_test:
+      print('test_error', test_error)
+      assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
+          test_error,)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--use_fp16',
+      default=False,
+      help='Use half floats instead of full floats if True.',
+      action='store_true')
+  parser.add_argument(
+      '--self_test',
+      default=False,
+      action='store_true',
+      help='True if running a self test.')
+
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

+ 80 - 0
tutorials/rnn/BUILD

@@ -0,0 +1,80 @@
+# Description:
+# Example RNN models, including language models and sequence-to-sequence models.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "linear",
+    srcs = [
+        "linear.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "rnn_cell",
+    srcs = [
+        "rnn_cell.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":linear",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "package",
+    srcs = [
+        "__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":rnn",
+        ":rnn_cell",
+        ":seq2seq",
+    ],
+)
+
+py_library(
+    name = "rnn",
+    srcs = [
+        "rnn.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":rnn_cell",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "seq2seq",
+    srcs = [
+        "seq2seq.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":rnn",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 13 - 0
tutorials/rnn/README.md

@@ -0,0 +1,13 @@
+This directory contains functions for creating recurrent neural networks
+and sequence-to-sequence models. Detailed instructions on how to get started
+and use them are available in the tutorials.
+
+* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/index.md)
+* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/index.md)
+
+Here is a short overview of what is in this directory.
+
+File | What's in it?
+--- | ---
+`ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
+`translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)

+ 19 - 0
tutorials/rnn/__init__.py

@@ -0,0 +1,19 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Libraries to build Recurrent Neural Networks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

+ 20 - 0
tutorials/rnn/linear.py

@@ -0,0 +1,20 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import linear python op for backward compatibility."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+raise ImportError("This module is deprecated.  Use tf.contrib.layers.linear.")

+ 61 - 0
tutorials/rnn/ptb/BUILD

@@ -0,0 +1,61 @@
+# Description:
+# Python support for TensorFlow.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "package",
+    srcs = [
+        "__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":reader",
+    ],
+)
+
+py_library(
+    name = "reader",
+    srcs = ["reader.py"],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_test(
+    name = "reader_test",
+    size = "small",
+    srcs = ["reader_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":reader",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
+    name = "ptb_word_lm",
+    srcs = [
+        "ptb_word_lm.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":reader",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 21 - 0
tutorials/rnn/ptb/__init__.py

@@ -0,0 +1,21 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Makes helper libraries available in the ptb package."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.models.rnn.ptb import reader

+ 371 - 0
tutorials/rnn/ptb/ptb_word_lm.py

@@ -0,0 +1,371 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example / benchmark for building a PTB LSTM model.
+
+Trains the model described in:
+(Zaremba, et. al.) Recurrent Neural Network Regularization
+http://arxiv.org/abs/1409.2329
+
+There are 3 supported model configurations:
+===========================================
+| config | epochs | train | valid  | test
+===========================================
+| small  | 13     | 37.99 | 121.39 | 115.91
+| medium | 39     | 48.45 |  86.16 |  82.07
+| large  | 55     | 37.87 |  82.62 |  78.29
+The exact results may vary depending on the random initialization.
+
+The hyperparameters used in the model:
+- init_scale - the initial scale of the weights
+- learning_rate - the initial value of the learning rate
+- max_grad_norm - the maximum permissible norm of the gradient
+- num_layers - the number of LSTM layers
+- num_steps - the number of unrolled steps of LSTM
+- hidden_size - the number of LSTM units
+- max_epoch - the number of epochs trained with the initial learning rate
+- max_max_epoch - the total number of epochs for training
+- keep_prob - the probability of keeping weights in the dropout layer
+- lr_decay - the decay of the learning rate for each epoch after "max_epoch"
+- batch_size - the batch size
+
+The data required for this example is in the data/ dir of the
+PTB dataset from Tomas Mikolov's webpage:
+
+$ wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+$ tar xvf simple-examples.tgz
+
+To run:
+
+$ python ptb_word_lm.py --data_path=simple-examples/data/
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.models.rnn.ptb import reader
+
+flags = tf.flags
+logging = tf.logging
+
+flags.DEFINE_string(
+    "model", "small",
+    "A type of model. Possible options are: small, medium, large.")
+flags.DEFINE_string("data_path", None,
+                    "Where the training/test data is stored.")
+flags.DEFINE_string("save_path", None,
+                    "Model output directory.")
+flags.DEFINE_bool("use_fp16", False,
+                  "Train using 16-bit floats instead of 32bit floats")
+
+FLAGS = flags.FLAGS
+
+
+def data_type():
+  return tf.float16 if FLAGS.use_fp16 else tf.float32
+
+
+class PTBInput(object):
+  """The input data."""
+
+  def __init__(self, config, data, name=None):
+    self.batch_size = batch_size = config.batch_size
+    self.num_steps = num_steps = config.num_steps
+    self.epoch_size = ((len(data) // batch_size) - 1) // num_steps
+    self.input_data, self.targets = reader.ptb_producer(
+        data, batch_size, num_steps, name=name)
+
+
+class PTBModel(object):
+  """The PTB model."""
+
+  def __init__(self, is_training, config, input_):
+    self._input = input_
+
+    batch_size = input_.batch_size
+    num_steps = input_.num_steps
+    size = config.hidden_size
+    vocab_size = config.vocab_size
+
+    # Slightly better results can be obtained with forget gate biases
+    # initialized to 1 but the hyperparameters of the model would need to be
+    # different than reported in the paper.
+    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
+        size, forget_bias=0.0, state_is_tuple=True)
+    if is_training and config.keep_prob < 1:
+      lstm_cell = tf.contrib.rnn.DropoutWrapper(
+          lstm_cell, output_keep_prob=config.keep_prob)
+    cell = tf.contrib.rnn.MultiRNNCell(
+        [lstm_cell] * config.num_layers, state_is_tuple=True)
+
+    self._initial_state = cell.zero_state(batch_size, data_type())
+
+    with tf.device("/cpu:0"):
+      embedding = tf.get_variable(
+          "embedding", [vocab_size, size], dtype=data_type())
+      inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
+
+    if is_training and config.keep_prob < 1:
+      inputs = tf.nn.dropout(inputs, config.keep_prob)
+
+    # Simplified version of tensorflow.models.rnn.rnn.py's rnn().
+    # This builds an unrolled LSTM for tutorial purposes only.
+    # In general, use the rnn() or state_saving_rnn() from rnn.py.
+    #
+    # The alternative version of the code below is:
+    #
+    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
+    # outputs, state = tf.nn.rnn(cell, inputs,
+    #                            initial_state=self._initial_state)
+    outputs = []
+    state = self._initial_state
+    with tf.variable_scope("RNN"):
+      for time_step in range(num_steps):
+        if time_step > 0: tf.get_variable_scope().reuse_variables()
+        (cell_output, state) = cell(inputs[:, time_step, :], state)
+        outputs.append(cell_output)
+
+    output = tf.reshape(tf.concat_v2(outputs, 1), [-1, size])
+    softmax_w = tf.get_variable(
+        "softmax_w", [size, vocab_size], dtype=data_type())
+    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
+    logits = tf.matmul(output, softmax_w) + softmax_b
+    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
+        [logits],
+        [tf.reshape(input_.targets, [-1])],
+        [tf.ones([batch_size * num_steps], dtype=data_type())])
+    self._cost = cost = tf.reduce_sum(loss) / batch_size
+    self._final_state = state
+
+    if not is_training:
+      return
+
+    self._lr = tf.Variable(0.0, trainable=False)
+    tvars = tf.trainable_variables()
+    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+                                      config.max_grad_norm)
+    optimizer = tf.train.GradientDescentOptimizer(self._lr)
+    self._train_op = optimizer.apply_gradients(
+        zip(grads, tvars),
+        global_step=tf.contrib.framework.get_or_create_global_step())
+
+    self._new_lr = tf.placeholder(
+        tf.float32, shape=[], name="new_learning_rate")
+    self._lr_update = tf.assign(self._lr, self._new_lr)
+
+  def assign_lr(self, session, lr_value):
+    session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
+
+  @property
+  def input(self):
+    return self._input
+
+  @property
+  def initial_state(self):
+    return self._initial_state
+
+  @property
+  def cost(self):
+    return self._cost
+
+  @property
+  def final_state(self):
+    return self._final_state
+
+  @property
+  def lr(self):
+    return self._lr
+
+  @property
+  def train_op(self):
+    return self._train_op
+
+
+class SmallConfig(object):
+  """Small config."""
+  init_scale = 0.1
+  learning_rate = 1.0
+  max_grad_norm = 5
+  num_layers = 2
+  num_steps = 20
+  hidden_size = 200
+  max_epoch = 4
+  max_max_epoch = 13
+  keep_prob = 1.0
+  lr_decay = 0.5
+  batch_size = 20
+  vocab_size = 10000
+
+
+class MediumConfig(object):
+  """Medium config."""
+  init_scale = 0.05
+  learning_rate = 1.0
+  max_grad_norm = 5
+  num_layers = 2
+  num_steps = 35
+  hidden_size = 650
+  max_epoch = 6
+  max_max_epoch = 39
+  keep_prob = 0.5
+  lr_decay = 0.8
+  batch_size = 20
+  vocab_size = 10000
+
+
+class LargeConfig(object):
+  """Large config."""
+  init_scale = 0.04
+  learning_rate = 1.0
+  max_grad_norm = 10
+  num_layers = 2
+  num_steps = 35
+  hidden_size = 1500
+  max_epoch = 14
+  max_max_epoch = 55
+  keep_prob = 0.35
+  lr_decay = 1 / 1.15
+  batch_size = 20
+  vocab_size = 10000
+
+
+class TestConfig(object):
+  """Tiny config, for testing."""
+  init_scale = 0.1
+  learning_rate = 1.0
+  max_grad_norm = 1
+  num_layers = 1
+  num_steps = 2
+  hidden_size = 2
+  max_epoch = 1
+  max_max_epoch = 1
+  keep_prob = 1.0
+  lr_decay = 0.5
+  batch_size = 20
+  vocab_size = 10000
+
+
+def run_epoch(session, model, eval_op=None, verbose=False):
+  """Runs the model on the given data."""
+  start_time = time.time()
+  costs = 0.0
+  iters = 0
+  state = session.run(model.initial_state)
+
+  fetches = {
+      "cost": model.cost,
+      "final_state": model.final_state,
+  }
+  if eval_op is not None:
+    fetches["eval_op"] = eval_op
+
+  for step in range(model.input.epoch_size):
+    feed_dict = {}
+    for i, (c, h) in enumerate(model.initial_state):
+      feed_dict[c] = state[i].c
+      feed_dict[h] = state[i].h
+
+    vals = session.run(fetches, feed_dict)
+    cost = vals["cost"]
+    state = vals["final_state"]
+
+    costs += cost
+    iters += model.input.num_steps
+
+    if verbose and step % (model.input.epoch_size // 10) == 10:
+      print("%.3f perplexity: %.3f speed: %.0f wps" %
+            (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
+             iters * model.input.batch_size / (time.time() - start_time)))
+
+  return np.exp(costs / iters)
+
+
+def get_config():
+  if FLAGS.model == "small":
+    return SmallConfig()
+  elif FLAGS.model == "medium":
+    return MediumConfig()
+  elif FLAGS.model == "large":
+    return LargeConfig()
+  elif FLAGS.model == "test":
+    return TestConfig()
+  else:
+    raise ValueError("Invalid model: %s", FLAGS.model)
+
+
+def main(_):
+  if not FLAGS.data_path:
+    raise ValueError("Must set --data_path to PTB data directory")
+
+  raw_data = reader.ptb_raw_data(FLAGS.data_path)
+  train_data, valid_data, test_data, _ = raw_data
+
+  config = get_config()
+  eval_config = get_config()
+  eval_config.batch_size = 1
+  eval_config.num_steps = 1
+
+  with tf.Graph().as_default():
+    initializer = tf.random_uniform_initializer(-config.init_scale,
+                                                config.init_scale)
+
+    with tf.name_scope("Train"):
+      train_input = PTBInput(config=config, data=train_data, name="TrainInput")
+      with tf.variable_scope("Model", reuse=None, initializer=initializer):
+        m = PTBModel(is_training=True, config=config, input_=train_input)
+      tf.contrib.deprecated.scalar_summary("Training Loss", m.cost)
+      tf.contrib.deprecated.scalar_summary("Learning Rate", m.lr)
+
+    with tf.name_scope("Valid"):
+      valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
+      with tf.variable_scope("Model", reuse=True, initializer=initializer):
+        mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
+      tf.contrib.deprecated.scalar_summary("Validation Loss", mvalid.cost)
+
+    with tf.name_scope("Test"):
+      test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
+      with tf.variable_scope("Model", reuse=True, initializer=initializer):
+        mtest = PTBModel(is_training=False, config=eval_config,
+                         input_=test_input)
+
+    sv = tf.train.Supervisor(logdir=FLAGS.save_path)
+    with sv.managed_session() as session:
+      for i in range(config.max_max_epoch):
+        lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
+        m.assign_lr(session, config.learning_rate * lr_decay)
+
+        print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
+        train_perplexity = run_epoch(session, m, eval_op=m.train_op,
+                                     verbose=True)
+        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
+        valid_perplexity = run_epoch(session, mvalid)
+        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
+
+      test_perplexity = run_epoch(session, mtest)
+      print("Test Perplexity: %.3f" % test_perplexity)
+
+      if FLAGS.save_path:
+        print("Saving model to %s." % FLAGS.save_path)
+        sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 122 - 0
tutorials/rnn/ptb/reader.py

@@ -0,0 +1,122 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Utilities for parsing PTB text files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+
+import tensorflow as tf
+
+
+def _read_words(filename):
+  with tf.gfile.GFile(filename, "r") as f:
+    return f.read().decode("utf-8").replace("\n", "<eos>").split()
+
+
+def _build_vocab(filename):
+  data = _read_words(filename)
+
+  counter = collections.Counter(data)
+  count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+
+  words, _ = list(zip(*count_pairs))
+  word_to_id = dict(zip(words, range(len(words))))
+
+  return word_to_id
+
+
+def _file_to_word_ids(filename, word_to_id):
+  data = _read_words(filename)
+  return [word_to_id[word] for word in data if word in word_to_id]
+
+
+def ptb_raw_data(data_path=None):
+  """Load PTB raw data from data directory "data_path".
+
+  Reads PTB text files, converts strings to integer ids,
+  and performs mini-batching of the inputs.
+
+  The PTB dataset comes from Tomas Mikolov's webpage:
+
+  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+
+  Args:
+    data_path: string path to the directory where simple-examples.tgz has
+      been extracted.
+
+  Returns:
+    tuple (train_data, valid_data, test_data, vocabulary)
+    where each of the data objects can be passed to PTBIterator.
+  """
+
+  train_path = os.path.join(data_path, "ptb.train.txt")
+  valid_path = os.path.join(data_path, "ptb.valid.txt")
+  test_path = os.path.join(data_path, "ptb.test.txt")
+
+  word_to_id = _build_vocab(train_path)
+  train_data = _file_to_word_ids(train_path, word_to_id)
+  valid_data = _file_to_word_ids(valid_path, word_to_id)
+  test_data = _file_to_word_ids(test_path, word_to_id)
+  vocabulary = len(word_to_id)
+  return train_data, valid_data, test_data, vocabulary
+
+
+def ptb_producer(raw_data, batch_size, num_steps, name=None):
+  """Iterate on the raw PTB data.
+
+  This chunks up raw_data into batches of examples and returns Tensors that
+  are drawn from these batches.
+
+  Args:
+    raw_data: one of the raw data outputs from ptb_raw_data.
+    batch_size: int, the batch size.
+    num_steps: int, the number of unrolls.
+    name: the name of this operation (optional).
+
+  Returns:
+    A pair of Tensors, each shaped [batch_size, num_steps]. The second element
+    of the tuple is the same data time-shifted to the right by one.
+
+  Raises:
+    tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
+  """
+  with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
+    raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
+
+    data_len = tf.size(raw_data)
+    batch_len = data_len // batch_size
+    data = tf.reshape(raw_data[0 : batch_size * batch_len],
+                      [batch_size, batch_len])
+
+    epoch_size = (batch_len - 1) // num_steps
+    assertion = tf.assert_positive(
+        epoch_size,
+        message="epoch_size == 0, decrease batch_size or num_steps")
+    with tf.control_dependencies([assertion]):
+      epoch_size = tf.identity(epoch_size, name="epoch_size")
+
+    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
+    x = tf.strided_slice(data, [0, i * num_steps],
+                         [batch_size, (i + 1) * num_steps])
+    x.set_shape([batch_size, num_steps])
+    y = tf.strided_slice(data, [0, i * num_steps + 1],
+                         [batch_size, (i + 1) * num_steps + 1])
+    y.set_shape([batch_size, num_steps])
+    return x, y

+ 68 - 0
tutorials/rnn/ptb/reader_test.py

@@ -0,0 +1,68 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow.models.ptb_lstm.ptb_reader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+from tensorflow.models.rnn.ptb import reader
+
+
+class PtbReaderTest(tf.test.TestCase):
+
+  def setUp(self):
+    self._string_data = "\n".join(
+        [" hello there i am",
+         " rain as day",
+         " want some cheesy puffs ?"])
+
+  def testPtbRawData(self):
+    tmpdir = tf.test.get_temp_dir()
+    for suffix in "train", "valid", "test":
+      filename = os.path.join(tmpdir, "ptb.%s.txt" % suffix)
+      with tf.gfile.GFile(filename, "w") as fh:
+        fh.write(self._string_data)
+    # Smoke test
+    output = reader.ptb_raw_data(tmpdir)
+    self.assertEqual(len(output), 4)
+
+  def testPtbProducer(self):
+    raw_data = [4, 3, 2, 1, 0, 5, 6, 1, 1, 1, 1, 0, 3, 4, 1]
+    batch_size = 3
+    num_steps = 2
+    x, y = reader.ptb_producer(raw_data, batch_size, num_steps)
+    with self.test_session() as session:
+      coord = tf.train.Coordinator()
+      tf.train.start_queue_runners(session, coord=coord)
+      try:
+        xval, yval = session.run([x, y])
+        self.assertAllEqual(xval, [[4, 3], [5, 6], [1, 0]])
+        self.assertAllEqual(yval, [[3, 2], [6, 1], [0, 3]])
+        xval, yval = session.run([x, y])
+        self.assertAllEqual(xval, [[2, 1], [1, 1], [3, 4]])
+        self.assertAllEqual(yval, [[1, 0], [1, 1], [4, 1]])
+      finally:
+        coord.request_stop()
+        coord.join()
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 21 - 0
tutorials/rnn/rnn.py

@@ -0,0 +1,21 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import rnn python ops for backward compatibility."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+raise ImportError("This module is deprecated.  Use tf.nn.rnn_* instead.")

+ 21 - 0
tutorials/rnn/rnn_cell.py

@@ -0,0 +1,21 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import rnn_cell python ops for backward compatibility."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+raise ImportError("This module is deprecated.  Use tf.contrib.rnn instead.")

+ 22 - 0
tutorials/rnn/seq2seq.py

@@ -0,0 +1,22 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import seq2seq python ops for backward compatibility."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+raise ImportError(
+    "This module is deprecated. Use tf.contrib.legacy_seq2seq instead.")

+ 84 - 0
tutorials/rnn/translate/BUILD

@@ -0,0 +1,84 @@
+# Description:
+# Example neural translation models.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "package",
+    srcs = [
+        "__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_utils",
+        ":seq2seq_model",
+    ],
+)
+
+py_library(
+    name = "data_utils",
+    srcs = [
+        "data_utils.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_library(
+    name = "seq2seq_model",
+    srcs = [
+        "seq2seq_model.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_utils",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
+    name = "translate",
+    srcs = [
+        "translate.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_utils",
+        ":seq2seq_model",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "translate_test",
+    size = "medium",
+    srcs = [
+        "translate.py",
+    ],
+    args = [
+        "--self_test=True",
+    ],
+    main = "translate.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_utils",
+        ":seq2seq_model",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)

+ 22 - 0
tutorials/rnn/translate/__init__.py

@@ -0,0 +1,22 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Makes helper libraries available in the translate package."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.models.rnn.translate import data_utils
+from tensorflow.models.rnn.translate import seq2seq_model

+ 290 - 0
tutorials/rnn/translate/data_utils.py

@@ -0,0 +1,290 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import re
+import tarfile
+
+from six.moves import urllib
+
+from tensorflow.python.platform import gfile
+import tensorflow as tf
+
+# Special vocabulary symbols - we always put them at the start.
+_PAD = b"_PAD"
+_GO = b"_GO"
+_EOS = b"_EOS"
+_UNK = b"_UNK"
+_START_VOCAB = [_PAD, _GO, _EOS, _UNK]
+
+PAD_ID = 0
+GO_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+
+# Regular expressions used to tokenize.
+_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
+_DIGIT_RE = re.compile(br"\d")
+
+# URLs for WMT data.
+_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
+_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
+
+
+def maybe_download(directory, filename, url):
+  """Download filename from url unless it's already in directory."""
+  if not os.path.exists(directory):
+    print("Creating directory %s" % directory)
+    os.mkdir(directory)
+  filepath = os.path.join(directory, filename)
+  if not os.path.exists(filepath):
+    print("Downloading %s to %s" % (url, filepath))
+    filepath, _ = urllib.request.urlretrieve(url, filepath)
+    statinfo = os.stat(filepath)
+    print("Successfully downloaded", filename, statinfo.st_size, "bytes")
+  return filepath
+
+
+def gunzip_file(gz_path, new_path):
+  """Unzips from gz_path into new_path."""
+  print("Unpacking %s to %s" % (gz_path, new_path))
+  with gzip.open(gz_path, "rb") as gz_file:
+    with open(new_path, "wb") as new_file:
+      for line in gz_file:
+        new_file.write(line)
+
+
+def get_wmt_enfr_train_set(directory):
+  """Download the WMT en-fr training corpus to directory unless it's there."""
+  train_path = os.path.join(directory, "giga-fren.release2.fixed")
+  if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
+    corpus_file = maybe_download(directory, "training-giga-fren.tar",
+                                 _WMT_ENFR_TRAIN_URL)
+    print("Extracting tar file %s" % corpus_file)
+    with tarfile.open(corpus_file, "r") as corpus_tar:
+      corpus_tar.extractall(directory)
+    gunzip_file(train_path + ".fr.gz", train_path + ".fr")
+    gunzip_file(train_path + ".en.gz", train_path + ".en")
+  return train_path
+
+
+def get_wmt_enfr_dev_set(directory):
+  """Download the WMT en-fr training corpus to directory unless it's there."""
+  dev_name = "newstest2013"
+  dev_path = os.path.join(directory, dev_name)
+  if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
+    dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
+    print("Extracting tgz file %s" % dev_file)
+    with tarfile.open(dev_file, "r:gz") as dev_tar:
+      fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
+      en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
+      fr_dev_file.name = dev_name + ".fr"  # Extract without "dev/" prefix.
+      en_dev_file.name = dev_name + ".en"
+      dev_tar.extract(fr_dev_file, directory)
+      dev_tar.extract(en_dev_file, directory)
+  return dev_path
+
+
+def basic_tokenizer(sentence):
+  """Very basic tokenizer: split the sentence into a list of tokens."""
+  words = []
+  for space_separated_fragment in sentence.strip().split():
+    words.extend(_WORD_SPLIT.split(space_separated_fragment))
+  return [w for w in words if w]
+
+
+def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
+                      tokenizer=None, normalize_digits=True):
+  """Create vocabulary file (if it does not exist yet) from data file.
+
+  Data file is assumed to contain one sentence per line. Each sentence is
+  tokenized and digits are normalized (if normalize_digits is set).
+  Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
+  We write it to vocabulary_path in a one-token-per-line format, so that later
+  token in the first line gets id=0, second line gets id=1, and so on.
+
+  Args:
+    vocabulary_path: path where the vocabulary will be created.
+    data_path: data file that will be used to create vocabulary.
+    max_vocabulary_size: limit on the size of the created vocabulary.
+    tokenizer: a function to use to tokenize each data sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+  """
+  if not gfile.Exists(vocabulary_path):
+    print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
+    vocab = {}
+    with gfile.GFile(data_path, mode="rb") as f:
+      counter = 0
+      for line in f:
+        counter += 1
+        if counter % 100000 == 0:
+          print("  processing line %d" % counter)
+        line = tf.compat.as_bytes(line)
+        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
+        for w in tokens:
+          word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w
+          if word in vocab:
+            vocab[word] += 1
+          else:
+            vocab[word] = 1
+      vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
+      if len(vocab_list) > max_vocabulary_size:
+        vocab_list = vocab_list[:max_vocabulary_size]
+      with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
+        for w in vocab_list:
+          vocab_file.write(w + b"\n")
+
+
+def initialize_vocabulary(vocabulary_path):
+  """Initialize vocabulary from file.
+
+  We assume the vocabulary is stored one-item-per-line, so a file:
+    dog
+    cat
+  will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
+  also return the reversed-vocabulary ["dog", "cat"].
+
+  Args:
+    vocabulary_path: path to the file containing the vocabulary.
+
+  Returns:
+    a pair: the vocabulary (a dictionary mapping string to integers), and
+    the reversed vocabulary (a list, which reverses the vocabulary mapping).
+
+  Raises:
+    ValueError: if the provided vocabulary_path does not exist.
+  """
+  if gfile.Exists(vocabulary_path):
+    rev_vocab = []
+    with gfile.GFile(vocabulary_path, mode="rb") as f:
+      rev_vocab.extend(f.readlines())
+    rev_vocab = [line.strip() for line in rev_vocab]
+    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
+    return vocab, rev_vocab
+  else:
+    raise ValueError("Vocabulary file %s not found.", vocabulary_path)
+
+
+def sentence_to_token_ids(sentence, vocabulary,
+                          tokenizer=None, normalize_digits=True):
+  """Convert a string to list of integers representing token-ids.
+
+  For example, a sentence "I have a dog" may become tokenized into
+  ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
+  "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].
+
+  Args:
+    sentence: the sentence in bytes format to convert to token-ids.
+    vocabulary: a dictionary mapping tokens to integers.
+    tokenizer: a function to use to tokenize each sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+
+  Returns:
+    a list of integers, the token-ids for the sentence.
+  """
+
+  if tokenizer:
+    words = tokenizer(sentence)
+  else:
+    words = basic_tokenizer(sentence)
+  if not normalize_digits:
+    return [vocabulary.get(w, UNK_ID) for w in words]
+  # Normalize digits by 0 before looking words up in the vocabulary.
+  return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words]
+
+
+def data_to_token_ids(data_path, target_path, vocabulary_path,
+                      tokenizer=None, normalize_digits=True):
+  """Tokenize data file and turn into token-ids using given vocabulary file.
+
+  This function loads data line-by-line from data_path, calls the above
+  sentence_to_token_ids, and saves the result to target_path. See comment
+  for sentence_to_token_ids on the details of token-ids format.
+
+  Args:
+    data_path: path to the data file in one-sentence-per-line format.
+    target_path: path where the file with token-ids will be created.
+    vocabulary_path: path to the vocabulary file.
+    tokenizer: a function to use to tokenize each sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+  """
+  if not gfile.Exists(target_path):
+    print("Tokenizing data in %s" % data_path)
+    vocab, _ = initialize_vocabulary(vocabulary_path)
+    with gfile.GFile(data_path, mode="rb") as data_file:
+      with gfile.GFile(target_path, mode="w") as tokens_file:
+        counter = 0
+        for line in data_file:
+          counter += 1
+          if counter % 100000 == 0:
+            print("  tokenizing line %d" % counter)
+          token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
+                                            tokenizer, normalize_digits)
+          tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
+
+
+def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer=None):
+  """Get WMT data into data_dir, create vocabularies and tokenize data.
+
+  Args:
+    data_dir: directory in which the data sets will be stored.
+    en_vocabulary_size: size of the English vocabulary to create and use.
+    fr_vocabulary_size: size of the French vocabulary to create and use.
+    tokenizer: a function to use to tokenize each data sentence;
+      if None, basic_tokenizer will be used.
+
+  Returns:
+    A tuple of 6 elements:
+      (1) path to the token-ids for English training data-set,
+      (2) path to the token-ids for French training data-set,
+      (3) path to the token-ids for English development data-set,
+      (4) path to the token-ids for French development data-set,
+      (5) path to the English vocabulary file,
+      (6) path to the French vocabulary file.
+  """
+  # Get wmt data to the specified directory.
+  train_path = get_wmt_enfr_train_set(data_dir)
+  dev_path = get_wmt_enfr_dev_set(data_dir)
+
+  # Create vocabularies of the appropriate sizes.
+  fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
+  en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
+  create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size, tokenizer)
+  create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size, tokenizer)
+
+  # Create token ids for the training data.
+  fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
+  en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
+  data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path, tokenizer)
+  data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path, tokenizer)
+
+  # Create token ids for the development data.
+  fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
+  en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
+  data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path, tokenizer)
+  data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path, tokenizer)
+
+  return (en_train_ids_path, fr_train_ids_path,
+          en_dev_ids_path, fr_dev_ids_path,
+          en_vocab_path, fr_vocab_path)

+ 313 - 0
tutorials/rnn/translate/seq2seq_model.py

@@ -0,0 +1,313 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Sequence-to-sequence model with an attention mechanism."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.models.rnn.translate import data_utils
+
+
+class Seq2SeqModel(object):
+  """Sequence-to-sequence model with attention and for multiple buckets.
+
+  This class implements a multi-layer recurrent neural network as encoder,
+  and an attention-based decoder. This is the same as the model described in
+  this paper: http://arxiv.org/abs/1412.7449 - please look there for details,
+  or into the seq2seq library for complete model implementation.
+  This class also allows to use GRU cells in addition to LSTM cells, and
+  sampled softmax to handle large output vocabulary size. A single-layer
+  version of this model, but with bi-directional encoder, was presented in
+    http://arxiv.org/abs/1409.0473
+  and sampled softmax is described in Section 3 of the following paper.
+    http://arxiv.org/abs/1412.2007
+  """
+
+  def __init__(self,
+               source_vocab_size,
+               target_vocab_size,
+               buckets,
+               size,
+               num_layers,
+               max_gradient_norm,
+               batch_size,
+               learning_rate,
+               learning_rate_decay_factor,
+               use_lstm=False,
+               num_samples=512,
+               forward_only=False,
+               dtype=tf.float32):
+    """Create the model.
+
+    Args:
+      source_vocab_size: size of the source vocabulary.
+      target_vocab_size: size of the target vocabulary.
+      buckets: a list of pairs (I, O), where I specifies maximum input length
+        that will be processed in that bucket, and O specifies maximum output
+        length. Training instances that have inputs longer than I or outputs
+        longer than O will be pushed to the next bucket and padded accordingly.
+        We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
+      size: number of units in each layer of the model.
+      num_layers: number of layers in the model.
+      max_gradient_norm: gradients will be clipped to maximally this norm.
+      batch_size: the size of the batches used during training;
+        the model construction is independent of batch_size, so it can be
+        changed after initialization if this is convenient, e.g., for decoding.
+      learning_rate: learning rate to start with.
+      learning_rate_decay_factor: decay learning rate by this much when needed.
+      use_lstm: if true, we use LSTM cells instead of GRU cells.
+      num_samples: number of samples for sampled softmax.
+      forward_only: if set, we do not construct the backward pass in the model.
+      dtype: the data type to use to store internal variables.
+    """
+    self.source_vocab_size = source_vocab_size
+    self.target_vocab_size = target_vocab_size
+    self.buckets = buckets
+    self.batch_size = batch_size
+    self.learning_rate = tf.Variable(
+        float(learning_rate), trainable=False, dtype=dtype)
+    self.learning_rate_decay_op = self.learning_rate.assign(
+        self.learning_rate * learning_rate_decay_factor)
+    self.global_step = tf.Variable(0, trainable=False)
+
+    # If we use sampled softmax, we need an output projection.
+    output_projection = None
+    softmax_loss_function = None
+    # Sampled softmax only makes sense if we sample less than vocabulary size.
+    if num_samples > 0 and num_samples < self.target_vocab_size:
+      w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype)
+      w = tf.transpose(w_t)
+      b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
+      output_projection = (w, b)
+
+      def sampled_loss(labels, inputs):
+        labels = tf.reshape(labels, [-1, 1])
+        # We need to compute the sampled_softmax_loss using 32bit floats to
+        # avoid numerical instabilities.
+        local_w_t = tf.cast(w_t, tf.float32)
+        local_b = tf.cast(b, tf.float32)
+        local_inputs = tf.cast(inputs, tf.float32)
+        return tf.cast(
+            tf.nn.sampled_softmax_loss(
+                weights=local_w_t,
+                biases=local_b,
+                labels=labels,
+                inputs=local_inputs,
+                num_sampled=num_samples,
+                num_classes=self.target_vocab_size),
+            dtype)
+      softmax_loss_function = sampled_loss
+
+    # Create the internal multi-layer cell for our RNN.
+    single_cell = tf.contrib.rnn.GRUCell(size)
+    if use_lstm:
+      single_cell = tf.contrib.rnn.BasicLSTMCell(size)
+    cell = single_cell
+    if num_layers > 1:
+      cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
+
+    # The seq2seq function: we use embedding for the input and attention.
+    def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
+      return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
+          encoder_inputs,
+          decoder_inputs,
+          cell,
+          num_encoder_symbols=source_vocab_size,
+          num_decoder_symbols=target_vocab_size,
+          embedding_size=size,
+          output_projection=output_projection,
+          feed_previous=do_decode,
+          dtype=dtype)
+
+    # Feeds for inputs.
+    self.encoder_inputs = []
+    self.decoder_inputs = []
+    self.target_weights = []
+    for i in xrange(buckets[-1][0]):  # Last bucket is the biggest one.
+      self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+                                                name="encoder{0}".format(i)))
+    for i in xrange(buckets[-1][1] + 1):
+      self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
+                                                name="decoder{0}".format(i)))
+      self.target_weights.append(tf.placeholder(dtype, shape=[None],
+                                                name="weight{0}".format(i)))
+
+    # Our targets are decoder inputs shifted by one.
+    targets = [self.decoder_inputs[i + 1]
+               for i in xrange(len(self.decoder_inputs) - 1)]
+
+    # Training outputs and losses.
+    if forward_only:
+      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
+          self.encoder_inputs, self.decoder_inputs, targets,
+          self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
+          softmax_loss_function=softmax_loss_function)
+      # If we use output projection, we need to project outputs for decoding.
+      if output_projection is not None:
+        for b in xrange(len(buckets)):
+          self.outputs[b] = [
+              tf.matmul(output, output_projection[0]) + output_projection[1]
+              for output in self.outputs[b]
+          ]
+    else:
+      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
+          self.encoder_inputs, self.decoder_inputs, targets,
+          self.target_weights, buckets,
+          lambda x, y: seq2seq_f(x, y, False),
+          softmax_loss_function=softmax_loss_function)
+
+    # Gradients and SGD update operation for training the model.
+    params = tf.trainable_variables()
+    if not forward_only:
+      self.gradient_norms = []
+      self.updates = []
+      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
+      for b in xrange(len(buckets)):
+        gradients = tf.gradients(self.losses[b], params)
+        clipped_gradients, norm = tf.clip_by_global_norm(gradients,
+                                                         max_gradient_norm)
+        self.gradient_norms.append(norm)
+        self.updates.append(opt.apply_gradients(
+            zip(clipped_gradients, params), global_step=self.global_step))
+
+    self.saver = tf.train.Saver(tf.global_variables())
+
+  def step(self, session, encoder_inputs, decoder_inputs, target_weights,
+           bucket_id, forward_only):
+    """Run a step of the model feeding the given inputs.
+
+    Args:
+      session: tensorflow session to use.
+      encoder_inputs: list of numpy int vectors to feed as encoder inputs.
+      decoder_inputs: list of numpy int vectors to feed as decoder inputs.
+      target_weights: list of numpy float vectors to feed as target weights.
+      bucket_id: which bucket of the model to use.
+      forward_only: whether to do the backward step or only forward.
+
+    Returns:
+      A triple consisting of gradient norm (or None if we did not do backward),
+      average perplexity, and the outputs.
+
+    Raises:
+      ValueError: if length of encoder_inputs, decoder_inputs, or
+        target_weights disagrees with bucket size for the specified bucket_id.
+    """
+    # Check if the sizes match.
+    encoder_size, decoder_size = self.buckets[bucket_id]
+    if len(encoder_inputs) != encoder_size:
+      raise ValueError("Encoder length must be equal to the one in bucket,"
+                       " %d != %d." % (len(encoder_inputs), encoder_size))
+    if len(decoder_inputs) != decoder_size:
+      raise ValueError("Decoder length must be equal to the one in bucket,"
+                       " %d != %d." % (len(decoder_inputs), decoder_size))
+    if len(target_weights) != decoder_size:
+      raise ValueError("Weights length must be equal to the one in bucket,"
+                       " %d != %d." % (len(target_weights), decoder_size))
+
+    # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
+    input_feed = {}
+    for l in xrange(encoder_size):
+      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
+    for l in xrange(decoder_size):
+      input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
+      input_feed[self.target_weights[l].name] = target_weights[l]
+
+    # Since our targets are decoder inputs shifted by one, we need one more.
+    last_target = self.decoder_inputs[decoder_size].name
+    input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
+
+    # Output feed: depends on whether we do a backward step or not.
+    if not forward_only:
+      output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
+                     self.gradient_norms[bucket_id],  # Gradient norm.
+                     self.losses[bucket_id]]  # Loss for this batch.
+    else:
+      output_feed = [self.losses[bucket_id]]  # Loss for this batch.
+      for l in xrange(decoder_size):  # Output logits.
+        output_feed.append(self.outputs[bucket_id][l])
+
+    outputs = session.run(output_feed, input_feed)
+    if not forward_only:
+      return outputs[1], outputs[2], None  # Gradient norm, loss, no outputs.
+    else:
+      return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.
+
+  def get_batch(self, data, bucket_id):
+    """Get a random batch of data from the specified bucket, prepare for step.
+
+    To feed data in step(..) it must be a list of batch-major vectors, while
+    data here contains single length-major cases. So the main logic of this
+    function is to re-index data cases to be in the proper format for feeding.
+
+    Args:
+      data: a tuple of size len(self.buckets) in which each element contains
+        lists of pairs of input and output data that we use to create a batch.
+      bucket_id: integer, which bucket to get the batch for.
+
+    Returns:
+      The triple (encoder_inputs, decoder_inputs, target_weights) for
+      the constructed batch that has the proper format to call step(...) later.
+    """
+    encoder_size, decoder_size = self.buckets[bucket_id]
+    encoder_inputs, decoder_inputs = [], []
+
+    # Get a random batch of encoder and decoder inputs from data,
+    # pad them if needed, reverse encoder inputs and add GO to decoder.
+    for _ in xrange(self.batch_size):
+      encoder_input, decoder_input = random.choice(data[bucket_id])
+
+      # Encoder inputs are padded and then reversed.
+      encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
+      encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))
+
+      # Decoder inputs get an extra "GO" symbol, and are padded then.
+      decoder_pad_size = decoder_size - len(decoder_input) - 1
+      decoder_inputs.append([data_utils.GO_ID] + decoder_input +
+                            [data_utils.PAD_ID] * decoder_pad_size)
+
+    # Now we create batch-major vectors from the data selected above.
+    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []
+
+    # Batch encoder inputs are just re-indexed encoder_inputs.
+    for length_idx in xrange(encoder_size):
+      batch_encoder_inputs.append(
+          np.array([encoder_inputs[batch_idx][length_idx]
+                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))
+
+    # Batch decoder inputs are re-indexed decoder_inputs, we create weights.
+    for length_idx in xrange(decoder_size):
+      batch_decoder_inputs.append(
+          np.array([decoder_inputs[batch_idx][length_idx]
+                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))
+
+      # Create target_weights to be 0 for targets that are padding.
+      batch_weight = np.ones(self.batch_size, dtype=np.float32)
+      for batch_idx in xrange(self.batch_size):
+        # We set weight to 0 if the corresponding target is a PAD symbol.
+        # The corresponding target is decoder_input shifted by 1 forward.
+        if length_idx < decoder_size - 1:
+          target = decoder_inputs[batch_idx][length_idx + 1]
+        if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
+          batch_weight[batch_idx] = 0.0
+      batch_weights.append(batch_weight)
+    return batch_encoder_inputs, batch_decoder_inputs, batch_weights

+ 297 - 0
tutorials/rnn/translate/translate.py

@@ -0,0 +1,297 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Binary for training translation models and decoding from them.
+
+Running this program without --decode will download the WMT corpus into
+the directory specified as --data_dir and tokenize it in a very basic way,
+and then start training a model saving checkpoints to --train_dir.
+
+Running with --decode starts an interactive loop so you can see how
+the current checkpoint translates English sentences into French.
+
+See the following papers for more information on neural translation models.
+ * http://arxiv.org/abs/1409.3215
+ * http://arxiv.org/abs/1409.0473
+ * http://arxiv.org/abs/1412.2007
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+import random
+import sys
+import time
+import logging
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.models.rnn.translate import data_utils
+from tensorflow.models.rnn.translate import seq2seq_model
+
+
+tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
+tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99,
+                          "Learning rate decays by this much.")
+tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
+                          "Clip gradients to this norm.")
+tf.app.flags.DEFINE_integer("batch_size", 64,
+                            "Batch size to use during training.")
+tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
+tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
+tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
+tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
+tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
+tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
+tf.app.flags.DEFINE_integer("max_train_data_size", 0,
+                            "Limit on the size of training data (0: no limit).")
+tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
+                            "How many training steps to do per checkpoint.")
+tf.app.flags.DEFINE_boolean("decode", False,
+                            "Set to True for interactive decoding.")
+tf.app.flags.DEFINE_boolean("self_test", False,
+                            "Run a self-test if this is set to True.")
+tf.app.flags.DEFINE_boolean("use_fp16", False,
+                            "Train using fp16 instead of fp32.")
+
+FLAGS = tf.app.flags.FLAGS
+
+# We use a number of buckets and pad to the closest one for efficiency.
+# See seq2seq_model.Seq2SeqModel for details of how they work.
+_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
+
+
+def read_data(source_path, target_path, max_size=None):
+  """Read data from source and target files and put into buckets.
+
+  Args:
+    source_path: path to the files with token-ids for the source language.
+    target_path: path to the file with token-ids for the target language;
+      it must be aligned with the source file: n-th line contains the desired
+      output for n-th line from the source_path.
+    max_size: maximum number of lines to read, all other will be ignored;
+      if 0 or None, data files will be read completely (no limit).
+
+  Returns:
+    data_set: a list of length len(_buckets); data_set[n] contains a list of
+      (source, target) pairs read from the provided data files that fit
+      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
+      len(target) < _buckets[n][1]; source and target are lists of token-ids.
+  """
+  data_set = [[] for _ in _buckets]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    with tf.gfile.GFile(target_path, mode="r") as target_file:
+      source, target = source_file.readline(), target_file.readline()
+      counter = 0
+      while source and target and (not max_size or counter < max_size):
+        counter += 1
+        if counter % 100000 == 0:
+          print("  reading data line %d" % counter)
+          sys.stdout.flush()
+        source_ids = [int(x) for x in source.split()]
+        target_ids = [int(x) for x in target.split()]
+        target_ids.append(data_utils.EOS_ID)
+        for bucket_id, (source_size, target_size) in enumerate(_buckets):
+          if len(source_ids) < source_size and len(target_ids) < target_size:
+            data_set[bucket_id].append([source_ids, target_ids])
+            break
+        source, target = source_file.readline(), target_file.readline()
+  return data_set
+
+
+def create_model(session, forward_only):
+  """Create translation model and initialize or load parameters in session."""
+  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
+  model = seq2seq_model.Seq2SeqModel(
+      FLAGS.en_vocab_size,
+      FLAGS.fr_vocab_size,
+      _buckets,
+      FLAGS.size,
+      FLAGS.num_layers,
+      FLAGS.max_gradient_norm,
+      FLAGS.batch_size,
+      FLAGS.learning_rate,
+      FLAGS.learning_rate_decay_factor,
+      forward_only=forward_only,
+      dtype=dtype)
+  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
+  if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
+    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
+    model.saver.restore(session, ckpt.model_checkpoint_path)
+  else:
+    print("Created model with fresh parameters.")
+    session.run(tf.global_variables_initializer())
+  return model
+
+
+def train():
+  """Train a en->fr translation model using WMT data."""
+  # Prepare WMT data.
+  print("Preparing WMT data in %s" % FLAGS.data_dir)
+  en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
+      FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
+
+  with tf.Session() as sess:
+    # Create model.
+    print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
+    model = create_model(sess, False)
+
+    # Read data into buckets and compute their sizes.
+    print ("Reading development and training data (limit: %d)."
+           % FLAGS.max_train_data_size)
+    dev_set = read_data(en_dev, fr_dev)
+    train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
+    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
+    train_total_size = float(sum(train_bucket_sizes))
+
+    # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
+    # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
+    # the size if i-th training bucket, as used later.
+    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
+                           for i in xrange(len(train_bucket_sizes))]
+
+    # This is the training loop.
+    step_time, loss = 0.0, 0.0
+    current_step = 0
+    previous_losses = []
+    while True:
+      # Choose a bucket according to data distribution. We pick a random number
+      # in [0, 1] and use the corresponding interval in train_buckets_scale.
+      random_number_01 = np.random.random_sample()
+      bucket_id = min([i for i in xrange(len(train_buckets_scale))
+                       if train_buckets_scale[i] > random_number_01])
+
+      # Get a batch and make a step.
+      start_time = time.time()
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          train_set, bucket_id)
+      _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                   target_weights, bucket_id, False)
+      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
+      loss += step_loss / FLAGS.steps_per_checkpoint
+      current_step += 1
+
+      # Once in a while, we save checkpoint, print statistics, and run evals.
+      if current_step % FLAGS.steps_per_checkpoint == 0:
+        # Print statistics for the previous epoch.
+        perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
+        print ("global step %d learning rate %.4f step-time %.2f perplexity "
+               "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
+                         step_time, perplexity))
+        # Decrease learning rate if no improvement was seen over last 3 times.
+        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
+          sess.run(model.learning_rate_decay_op)
+        previous_losses.append(loss)
+        # Save checkpoint and zero timer and loss.
+        checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
+        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
+        step_time, loss = 0.0, 0.0
+        # Run evals on development set and print their perplexity.
+        for bucket_id in xrange(len(_buckets)):
+          if len(dev_set[bucket_id]) == 0:
+            print("  eval: empty bucket %d" % (bucket_id))
+            continue
+          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+              dev_set, bucket_id)
+          _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
+                                       target_weights, bucket_id, True)
+          eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float(
+              "inf")
+          print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
+        sys.stdout.flush()
+
+
+def decode():
+  with tf.Session() as sess:
+    # Create model and load parameters.
+    model = create_model(sess, True)
+    model.batch_size = 1  # We decode one sentence at a time.
+
+    # Load vocabularies.
+    en_vocab_path = os.path.join(FLAGS.data_dir,
+                                 "vocab%d.en" % FLAGS.en_vocab_size)
+    fr_vocab_path = os.path.join(FLAGS.data_dir,
+                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
+    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
+    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
+
+    # Decode from standard input.
+    sys.stdout.write("> ")
+    sys.stdout.flush()
+    sentence = sys.stdin.readline()
+    while sentence:
+      # Get token-ids for the input sentence.
+      token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
+      # Which bucket does it belong to?
+      bucket_id = len(_buckets) - 1
+      for i, bucket in enumerate(_buckets):
+        if bucket[0] >= len(token_ids):
+          bucket_id = i
+          break
+      else:
+        logging.warning("Sentence truncated: %s", sentence) 
+
+      # Get a 1-element batch to feed the sentence to the model.
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          {bucket_id: [(token_ids, [])]}, bucket_id)
+      # Get output logits for the sentence.
+      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
+                                       target_weights, bucket_id, True)
+      # This is a greedy decoder - outputs are just argmaxes of output_logits.
+      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
+      # If there is an EOS symbol in outputs, cut them at that point.
+      if data_utils.EOS_ID in outputs:
+        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
+      # Print out French sentence corresponding to outputs.
+      print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs]))
+      print("> ", end="")
+      sys.stdout.flush()
+      sentence = sys.stdin.readline()
+
+
+def self_test():
+  """Test the translation model."""
+  with tf.Session() as sess:
+    print("Self-test for neural translation model.")
+    # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
+    model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
+                                       5.0, 32, 0.3, 0.99, num_samples=8)
+    sess.run(tf.global_variables_initializer())
+
+    # Fake data set for both the (3, 3) and (6, 6) bucket.
+    data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
+                [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
+    for _ in xrange(5):  # Train the fake model for 5 steps.
+      bucket_id = random.choice([0, 1])
+      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
+          data_set, bucket_id)
+      model.step(sess, encoder_inputs, decoder_inputs, target_weights,
+                 bucket_id, False)
+
+
+def main(_):
+  if FLAGS.self_test:
+    self_test()
+  elif FLAGS.decode:
+    decode()
+  else:
+    train()
+
+if __name__ == "__main__":
+  tf.app.run()