Pārlūkot izejas kodu

Adding Neural GPU code.

Lukasz Kaiser 9 gadi atpakaļ
vecāks
revīzija
fa65d8adc9

+ 41 - 0
neural_gpu/BUILD

@@ -0,0 +1,41 @@
+py_library(
+    name = "data_utils",
+    srcs = [
+        "data_utils.py",
+    ],
+    deps = [
+        "//file/colossus/public:cns",
+        "//third_party/py/numpy",
+        "//third_party/py/tensorflow",
+    ],
+)
+
+py_library(
+    name = "neural_gpu",
+    srcs = [
+        "neural_gpu.py",
+    ],
+    deps = [
+        ":data_utils",
+        "//third_party/py/numpy",
+        "//third_party/py/tensorflow",
+    ],
+)
+
+py_binary(
+    name = "neural_gpu_trainer",
+    srcs = [
+        "neural_gpu_trainer.py",
+    ],
+    launcher = "//devtools/python/launcher",
+    malloc = "//tcmalloc:tcmalloc_or_debug",
+    deps = [
+        ":neural_gpu",
+        "//file/colossus/public:cns",
+        "//net/proto2/python/public:use_fast_cpp_protos",
+        "//third_party/py/Tkinter",
+        "//third_party/py/matplotlib",
+        "//third_party/py/numpy",
+        "//third_party/py/tensorflow",
+    ],
+)

+ 4 - 0
neural_gpu/README.md

@@ -0,0 +1,4 @@
+# NeuralGPU
+Code for the Neural GPU model as described
+in [[http://arxiv.org/abs/1511.08228]].
+

+ 244 - 0
neural_gpu/data_utils.py

@@ -0,0 +1,244 @@
+"""Convolutional Gated Recurrent Networks for Algorithm Learning."""
+
+import math
+import random
+import sys
+import time
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+
+from google3.third_party.tensorflow.python.platform import gfile
+
+FLAGS = tf.app.flags.FLAGS
+
+bins = [8, 16, 32, 64, 128]
+all_tasks = ["sort", "id", "rev", "incr", "left", "right", "left-shift", "add",
+             "right-shift", "bmul", "dup", "badd", "qadd"]
+forward_max = 128
+log_filename = ""
+
+
+def pad(l):
+  for b in bins:
+    if b >= l: return b
+  return forward_max
+
+
+train_set = {}
+test_set = {}
+for some_task in all_tasks:
+  train_set[some_task] = []
+  test_set[some_task] = []
+  for all_max_len in xrange(10000):
+    train_set[some_task].append([])
+    test_set[some_task].append([])
+
+
+def add(n1, n2, base=10):
+  """Add two numbers represented as lower-endian digit lists."""
+  k = max(len(n1), len(n2)) + 1
+  d1 = n1 + [0 for _ in xrange(k - len(n1))]
+  d2 = n2 + [0 for _ in xrange(k - len(n2))]
+  res = []
+  carry = 0
+  for i in xrange(k):
+    if d1[i] + d2[i] + carry < base:
+      res.append(d1[i] + d2[i] + carry)
+      carry = 0
+    else:
+      res.append(d1[i] + d2[i] + carry - base)
+      carry = 1
+  while res and res[-1] == 0:
+    res = res[:-1]
+  if res: return res
+  return [0]
+
+
+def init_data(task, length, nbr_cases, nclass):
+  """Data initialization."""
+  def rand_pair(l, task):
+    """Random data pair for a task. Total length should be <= l."""
+    k = (l-1)/2
+    base = 10
+    if task[0] == "b": base = 2
+    if task[0] == "q": base = 4
+    d1 = [np.random.randint(base) for _ in xrange(k)]
+    d2 = [np.random.randint(base) for _ in xrange(k)]
+    if task in ["add", "badd", "qadd"]:
+      res = add(d1, d2, base)
+    elif task in ["bmul"]:
+      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
+      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
+      res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
+    else:
+      sys.exit()
+    sep = [12]
+    if task in ["add", "badd", "qadd"]: sep = [11]
+    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
+    return inp, [r + 1 for r in res]
+
+  def rand_dup_pair(l):
+    """Random data pair for duplication task. Total length should be <= l."""
+    k = l/2
+    x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
+    inp = x + [0 for _ in xrange(l - k)]
+    res = x + x + [0 for _ in xrange(l - 2*k)]
+    return inp, res
+
+  def spec(inp):
+    """Return the target given the input for some tasks."""
+    if task == "sort":
+      return sorted(inp)
+    elif task == "id":
+      return inp
+    elif task == "rev":
+      return [i for i in reversed(inp)]
+    elif task == "incr":
+      carry = 1
+      res = []
+      for i in xrange(len(inp)):
+        if inp[i] + carry < nclass:
+          res.append(inp[i] + carry)
+          carry = 0
+        else:
+          res.append(1)
+          carry = 1
+      return res
+    elif task == "left":
+      return [inp[0]]
+    elif task == "right":
+      return [inp[-1]]
+    elif task == "left-shift":
+      return [inp[l-1] for l in xrange(len(inp))]
+    elif task == "right-shift":
+      return [inp[l+1] for l in xrange(len(inp))]
+    else:
+      print_out("Unknown spec for task " + str(task))
+      sys.exit()
+
+  l = length
+  cur_time = time.time()
+  total_time = 0.0
+  for case in xrange(nbr_cases):
+    total_time += time.time() - cur_time
+    cur_time = time.time()
+    if l > 10000 and case % 100 == 1:
+      print_out("  avg gen time %.4f s" % (total_time / float(case)))
+    if task in ["add", "badd", "qadd", "bmul"]:
+      i, t = rand_pair(l, task)
+      train_set[task][len(i)].append([i, t])
+      i, t = rand_pair(l, task)
+      test_set[task][len(i)].append([i, t])
+    elif task == "dup":
+      i, t = rand_dup_pair(l)
+      train_set[task][len(i)].append([i, t])
+      i, t = rand_dup_pair(l)
+      test_set[task][len(i)].append([i, t])
+    else:
+      inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
+      target = spec(inp)
+      train_set[task][l].append([inp, target])
+      inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
+      target = spec(inp)
+      test_set[task][l].append([inp, target])
+
+
+def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
+  """Get a batch of data, training or testing."""
+  inputs = []
+  targets = []
+  length = max_length
+  if preset is None:
+    cur_set = test_set[task]
+    if do_train: cur_set = train_set[task]
+    while not cur_set[length]:
+      length -= 1
+  pad_length = pad(length)
+  for b in xrange(batch_size):
+    if preset is None:
+      elem = random.choice(cur_set[length])
+      if offset is not None and offset + b < len(cur_set[length]):
+        elem = cur_set[length][offset + b]
+    else:
+      elem = preset
+    inp, target = elem[0], elem[1]
+    assert len(inp) == length
+    inputs.append(inp + [0 for l in xrange(pad_length - len(inp))])
+    targets.append(target + [0 for l in xrange(pad_length - len(target))])
+  res_input = []
+  res_target = []
+  for l in xrange(pad_length):
+    new_input = np.array([inputs[b][l] for b in xrange(batch_size)],
+                         dtype=np.int32)
+    new_target = np.array([targets[b][l] for b in xrange(batch_size)],
+                          dtype=np.int32)
+    res_input.append(new_input)
+    res_target.append(new_target)
+  return res_input, res_target
+
+
+def print_out(s, newline=True):
+  """Print a message out and log it to file."""
+  if log_filename:
+    try:
+      with gfile.GFile(log_filename, mode="a") as f:
+        f.write(s + ("\n" if newline else ""))
+    # pylint: disable=bare-except
+    except:
+      sys.stdout.write("Error appending to %s\n" % log_filename)
+  sys.stdout.write(s + ("\n" if newline else ""))
+  sys.stdout.flush()
+
+
+def decode(output):
+  return [np.argmax(o, axis=1) for o in output]
+
+
+def accuracy(inpt, output, target, batch_size, nprint):
+  """Calculate output accuracy given target."""
+  assert nprint < batch_size + 1
+  def task_print(inp, output, target):
+    stop_bound = 0
+    print_len = 0
+    while print_len < len(target) and target[print_len] > stop_bound:
+      print_len += 1
+    print_out("    i: " + " ".join([str(i - 1) for i in inp if i > 0]))
+    print_out("    o: " +
+              " ".join([str(output[l] - 1) for l in xrange(print_len)]))
+    print_out("    t: " +
+              " ".join([str(target[l] - 1) for l in xrange(print_len)]))
+  decoded_target = target
+  decoded_output = decode(output)
+  total = 0
+  errors = 0
+  seq = [0 for b in xrange(batch_size)]
+  for l in xrange(len(decoded_output)):
+    for b in xrange(batch_size):
+      if decoded_target[l][b] > 0:
+        total += 1
+        if decoded_output[l][b] != decoded_target[l][b]:
+          seq[b] = 1
+          errors += 1
+  e = 0  # Previous error index
+  for _ in xrange(min(nprint, sum(seq))):
+    while seq[e] == 0:
+      e += 1
+    task_print([inpt[l][e] for l in xrange(len(inpt))],
+               [decoded_output[l][e] for l in xrange(len(decoded_target))],
+               [decoded_target[l][e] for l in xrange(len(decoded_target))])
+    e += 1
+  for b in xrange(nprint - errors):
+    task_print([inpt[l][b] for l in xrange(len(inpt))],
+               [decoded_output[l][b] for l in xrange(len(decoded_target))],
+               [decoded_target[l][b] for l in xrange(len(decoded_target))])
+  return errors, total, sum(seq)
+
+
+def safe_exp(x):
+  perp = 10000
+  if x < 100: perp = math.exp(x)
+  if perp > 10000: return 10000
+  return perp

+ 271 - 0
neural_gpu/neural_gpu.py

@@ -0,0 +1,271 @@
+"""The Neural GPU Model."""
+
+import time
+
+import google3
+
+import tensorflow as tf
+
+from google3.experimental.users.lukaszkaiser.neural_gpu import data_utils
+
+
+def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
+  """Convolutional linear map."""
+  assert args
+  if not isinstance(args, (list, tuple)):
+    args = [args]
+  with tf.variable_scope(prefix):
+    k = tf.get_variable("CvK", [kw, kh, nin, nout])
+    if len(args) == 1:
+      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
+    else:
+      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
+    if not do_bias: return res
+    bias_term = tf.get_variable("CvB", [nout],
+                                initializer=tf.constant_initializer(0.0))
+    return res + bias_term + bias_start
+
+
+def sigmoid_cutoff(x, cutoff):
+  """Sigmoid with cutoff, e.g., 1.2sigmoid(x) - 0.1."""
+  y = tf.sigmoid(x)
+  if cutoff < 1.01: return y
+  d = (cutoff - 1.0) / 2.0
+  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
+
+
+def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
+  """Convolutional GRU."""
+  def conv_lin(args, suffix, bias_start):
+    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
+                       prefix + "/" + suffix)
+  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
+  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
+  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
+  return gate * mem + (1 - gate) * candidate
+
+
+def relaxed_average(var_name_suffix, rx_step):
+  """Calculate the average of relaxed variables having var_name_suffix."""
+  relaxed_vars = []
+  for l in xrange(rx_step):
+    with tf.variable_scope("RX%d" % l, reuse=True):
+      try:
+        relaxed_vars.append(tf.get_variable(var_name_suffix))
+      except ValueError:
+        pass
+  dsum = tf.add_n(relaxed_vars)
+  avg = dsum / len(relaxed_vars)
+  diff = [v - avg for v in relaxed_vars]
+  davg = tf.add_n([d*d for d in diff])
+  return avg, tf.reduce_sum(davg)
+
+
+def relaxed_distance(rx_step):
+  """Distance between relaxed variables and their average."""
+  res, ops, rx_done = [], [], {}
+  for v in tf.trainable_variables():
+    if v.name[0:2] == "RX":
+      rx_name = v.op.name[v.name.find("/") + 1:]
+      if rx_name not in rx_done:
+        avg, dist_loss = relaxed_average(rx_name, rx_step)
+        res.append(dist_loss)
+        rx_done[rx_name] = avg
+      ops.append(v.assign(rx_done[rx_name]))
+  return tf.add_n(res), tf.group(*ops)
+
+
+def make_dense(targets, noclass):
+  """Move a batch of targets to a dense 1-hot representation."""
+  with tf.device("/cpu:0"):
+    shape = tf.shape(targets)
+    batch_size = shape[0]
+    indices = targets + noclass * tf.range(0, batch_size)
+    length = batch_size * noclass
+    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
+  return tf.reshape(dense, [-1, noclass])
+
+
+def check_for_zero(sparse):
+  """In a sparse batch of ints, make 1.0 if it's 0 and 0.0 else."""
+  with tf.device("/cpu:0"):
+    shape = tf.shape(sparse)
+    batch_size = shape[0]
+    sparse = tf.minimum(sparse, 1)
+    indices = sparse + 2 * tf.range(0, batch_size)
+    dense = tf.sparse_to_dense(indices, 2 * batch_size, 1.0, 0.0)
+    reshaped = tf.reshape(dense, [-1, 2])
+  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
+
+
+class NeuralGPU(object):
+  """Neural GPU Model."""
+
+  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
+               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
+               learning_rate, pull, pull_incr, min_length):
+    # Feeds for parameters and ops to update them.
+    self.global_step = tf.Variable(0, trainable=False)
+    self.cur_length = tf.Variable(min_length, trainable=False)
+    self.cur_length_incr_op = self.cur_length.assign_add(1)
+    self.lr = tf.Variable(float(learning_rate), trainable=False)
+    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
+    self.pull = tf.Variable(float(pull), trainable=False)
+    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
+    self.do_training = tf.placeholder(tf.float32, name="do_training")
+    self.noise_param = tf.placeholder(tf.float32, name="noise_param")
+
+    # Feeds for inputs, targets, outputs, losses, etc.
+    self.input = []
+    self.target = []
+    for l in xrange(data_utils.forward_max + 1):
+      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
+      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
+    self.outputs = []
+    self.losses = []
+    self.grad_norms = []
+    self.updates = []
+
+    # Computation.
+    inp0_shape = tf.shape(self.input[0])
+    batch_size = inp0_shape[0]
+    with tf.device("/cpu:0"):
+      emb_weights = tf.get_variable(
+          "embedding", [niclass, vec_size],
+          initializer=tf.random_uniform_initializer(-1.7, 1.7))
+      e0 = tf.scatter_update(emb_weights,
+                             tf.constant(0, dtype=tf.int32, shape=[1]),
+                             tf.zeros([1, vec_size]))
+
+    adam = tf.train.AdamOptimizer(0.01*self.lr, epsilon=1e-5)
+
+    # Main graph creation loop, for every bin in data_utils.
+    self.steps = []
+    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
+      data_utils.print_out("Creating model for bin of length %d." % length)
+      start_time = time.time()
+      if length > data_utils.bins[0]:
+        tf.get_variable_scope().reuse_variables()
+
+      # Embed inputs and calculate mask.
+      with tf.device("/cpu:0"):
+        with tf.control_dependencies([e0]):
+          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
+                      for l in xrange(length)]
+        # Mask to 0-out padding space in each step.
+        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
+        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
+        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
+        mask = [tf.reshape(m, [-1, 1]) for m in mask]
+        # Use a shifted mask for step scaling and concatenated for weights.
+        shifted_mask = mask + [tf.zeros_like(mask[0])]
+        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
+                  for i in xrange(length)]
+        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
+        mask = tf.concat(1, mask[0:length])  # batch x length
+        weights = mask
+        # Add a height dimension to mask to use later for masking.
+        mask = tf.reshape(mask, [-1, length, 1, 1])
+        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
+            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)
+
+      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
+      start = [tf.tanh(embedded[l]) for l in xrange(length)]
+      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
+      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])
+
+      # First image comes from start by applying one convolution and adding 0s.
+      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
+      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
+                                  dtype=tf.float32) for _ in xrange(height - 1)]
+      first = tf.concat(2, first)
+
+      # Computation steps.
+      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
+      outputs = []
+      for it in xrange(length):
+        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
+          if it >= rx_step:
+            vs.reuse_variables()
+          cur = step[it]
+          # Do nconvs-many CGRU steps.
+          for layer in xrange(nconvs):
+            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
+          cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
+          step.append(cur * mask)
+          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))
+
+      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
+                         for s in step])
+      # Output is the n-th step output; n = current length, as in scales.
+      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
+      # Final convolution to get logits, list outputs.
+      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
+      output = tf.reshape(output, [-1, length, noclass])
+      self.outputs.append([tf.reshape(o, [-1, noclass])
+                           for o in list(tf.split(1, length, output))])
+
+      # Calculate cross-entropy loss and normalize it.
+      targets = tf.concat(1, [make_dense(self.target[l], noclass)
+                              for l in xrange(length)])
+      targets = tf.reshape(targets, [-1, noclass])
+      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
+          tf.reshape(output, [-1, noclass]), targets), [-1, length])
+      perp_loss = tf.reduce_sum(xent * weights)
+      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
+      perp_loss /= length
+
+      # Final loss: cross-entropy + shared parameter relaxation part.
+      relax_dist, self.avg_op = relaxed_distance(rx_step)
+      total_loss = perp_loss + relax_dist * self.pull
+      self.losses.append(perp_loss)
+
+      # Gradients and Adam update operation.
+      if length == data_utils.bins[0] or (mode == 0 and
+                                          length < data_utils.bins[-1] + 1):
+        data_utils.print_out("Creating backward for bin of length %d." % length)
+        params = tf.trainable_variables()
+        grads = tf.gradients(total_loss, params)
+        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
+        self.grad_norms.append(norm)
+        for grad in grads:
+          if isinstance(grad, tf.Tensor):
+            grad += tf.truncated_normal(tf.shape(grad)) * self.noise_param
+        update = adam.apply_gradients(zip(grads, params),
+                                      global_step=self.global_step)
+        self.updates.append(update)
+      data_utils.print_out("Created model for bin of length %d in"
+                           " %.2f s." % (length, time.time() - start_time))
+    self.saver = tf.train.Saver(tf.all_variables())
+
+  def step(self, sess, inp, target, do_backward, noise_param=None):
+    """Run a step of the network."""
+    assert len(inp) == len(target)
+    length = len(target)
+    feed_in = {}
+    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
+    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
+    feed_out = []
+    index = len(data_utils.bins)
+    if length < data_utils.bins[-1] + 1:
+      index = data_utils.bins.index(length)
+    if do_backward:
+      feed_out.append(self.updates[index])
+      feed_out.append(self.grad_norms[index])
+    feed_out.append(self.losses[index])
+    for l in xrange(length):
+      feed_in[self.input[l].name] = inp[l]
+    for l in xrange(length):
+      feed_in[self.target[l].name] = target[l]
+      feed_out.append(self.outputs[index][l])
+    for l in xrange(length+1):
+      feed_out.append(self.steps[index][l])
+    res = sess.run(feed_out, feed_in)
+    offset = 0
+    norm = None
+    if do_backward:
+      offset = 2
+      norm = res[1]
+    outputs = res[offset + 1:offset + 1 + length]
+    steps = res[offset + 1 + length:]
+    return res[offset], outputs, norm, steps

+ 376 - 0
neural_gpu/neural_gpu_trainer.py

@@ -0,0 +1,376 @@
+"""Neural GPU for Learning Algorithms."""
+
+import math
+import os
+import random
+import sys
+import time
+
+import google3
+
+import matplotlib.animation as anim
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+from google3.third_party.tensorflow.python.platform import gfile
+import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
+import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu
+
+tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
+tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
+tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
+tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
+tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
+tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
+tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
+tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
+tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
+tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
+tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.")
+tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
+tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).")
+tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).")
+tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
+tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
+tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
+tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
+tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
+tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
+tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
+tf.app.flags.DEFINE_integer("height", 4, "Height.")
+tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
+tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
+tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
+tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.")
+tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
+tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def initialize(sess):
+  """Initialize data and model."""
+  if FLAGS.jobid >= 0:
+    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
+  data.print_out("NN ", newline=False)
+
+  # Set random seed.
+  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
+  tf.set_random_seed(seed)
+  random.seed(seed)
+  np.random.seed(seed)
+
+  # Check data sizes.
+  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
+  assert data.bins
+  min_length = 3
+  max_length = min(FLAGS.max_length, data.bins[-1])
+  assert max_length + 1 > min_length
+  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
+    data.bins = data.bins[:-1]
+  assert data.bins[0] > FLAGS.rx_step
+  nclass = min(FLAGS.niclass, FLAGS.noclass)
+  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000
+
+  # Initialize data for each task.
+  tasks = FLAGS.task.split("-")
+  for t in tasks:
+    for l in xrange(max_length + 11):
+      data.init_data(t, l, data_size, nclass)
+    data.init_data(t, data.bins[-2], data_size, nclass)
+    data.init_data(t, data.bins[-1], data_size, nclass)
+    end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
+    data.init_data(t, data.forward_max, end_size, nclass)
+
+  # Print out parameters.
+  curriculum = 0.12
+  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
+         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
+            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
+  fin = "data %d %s" % (FLAGS.train_data_size, fin)
+  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
+         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
+          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
+  data.print_out(tag)
+
+  # Create checkpoint directory if it does not exist.
+  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
+                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
+  if not gfile.IsDirectory(checkpoint_dir):
+    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
+    gfile.MkDir(checkpoint_dir)
+
+  # Create model and initialize it.
+  tf.get_variable_scope().set_initializer(
+      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
+  model = ngpu.NeuralGPU(
+      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
+      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
+      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
+      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
+  data.print_out("Created model.")
+  sess.run(tf.initialize_all_variables())
+  data.print_out("Initialized variables.")
+
+  # Load model from parameters if a checkpoint exists.
+  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
+    data.print_out("Reading model parameters from %s"
+                   % ckpt.model_checkpoint_path)
+    model.saver.restore(sess, ckpt.model_checkpoint_path)
+
+  # Return the model and needed variables.
+  return (model, min_length, max_length, checkpoint_dir, curriculum)
+
+
+def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
+                offset=None):
+  """Test model on test data of length l using the given session."""
+  inpt, target = data.get_batch(l, batch_size, False, task, offset)
+  _, res, _, steps = model.step(sess, inpt, target, False)
+  errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
+  seq = float(seq) / batch_size
+  if total > 0:
+    errors = float(errors) / total
+  if print_out:
+    data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
+                   % (task, l, 100*errors, 100*seq))
+  return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])
+
+
+def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
+  """Run multiple tests at lower batch size to save memory."""
+  errors = 0.0
+  seq = 0.0
+  to_print = nprint
+  low_batch = FLAGS.low_batch_size
+  low_batch = min(low_batch, batch_size)
+  for mstep in xrange(batch_size / low_batch):
+    cur_offset = None if offset is None else offset + mstep * low_batch
+    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
+                             cur_offset)
+    to_print = max(0, to_print - low_batch)
+    errors += err
+    seq += sq
+    if FLAGS.mode > 0:
+      cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
+      cur_seq = float(low_batch * seq) / ((mstep+1) * low_batch)
+      data.print_out("    %s multitest current errors %.2f sequence-errors %.2f"
+                     % (task, 100*cur_errors, 100*cur_seq))
+  errors = float(low_batch) * float(errors) / batch_size
+  seq = float(low_batch) * float(seq) / batch_size
+  data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
+                 % (task, l, 100*errors, 100*seq))
+  return errors, seq
+
+
+def train():
+  """Main training function."""
+  batch_size = FLAGS.batch_size
+  tasks = FLAGS.task.split("-")
+  with tf.Session() as sess:
+    model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
+    max_cur_length = min(min_length + 3, max_length)
+    prev_acc_perp = [1000000 for _ in xrange(3)]
+    prev_sq = 1.0
+
+    while True:
+      global_step, pull, max_cur_length, learning_rate = sess.run(
+          [model.global_step, model.pull, model.cur_length, model.lr])
+      ep = global_step / FLAGS.steps_per_checkpoint
+      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
+      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
+      for _ in xrange(FLAGS.steps_per_checkpoint):
+        global_step += 1
+        task = random.choice(tasks)
+        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
+        l = l1
+        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
+          l = np.random.randint(max_cur_length - min_length+1) + min_length
+          l = max(l, l1)
+        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
+          l = np.random.randint(max_length - min_length + 1) + min_length
+          l = max(l, l1)
+        start_time = time.time()
+        inp, target = data.get_batch(l, batch_size, True, task)
+        stepp = math.pow(global_step, -0.55)
+        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
+        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
+        step_time += time.time() - start_time
+        acc_grad_norm += float(gnorm)
+        if l < max_cur_length + 1:
+          step_count += 1
+          acc_loss += loss
+          errors, total, seq = data.accuracy(inp, res, target,
+                                             batch_size, 0)
+          acc_total += total
+          acc_errors += errors
+          acc_seq += seq
+      acc_loss /= step_count
+      step_time /= FLAGS.steps_per_checkpoint
+      acc_seq = float(acc_seq) / (step_count * batch_size)
+      prev_sq = acc_seq
+      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
+      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
+      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
+      msg = ("%s %s gn %.8f"
+             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
+      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
+                     (msg, max_cur_length, data.safe_exp(acc_loss),
+                      100*acc_errors, 100*acc_seq))
+      if curriculum > acc_seq:
+        prev_acc_perp.append(1000000)
+        do_incr = True
+        while do_incr and max_cur_length < max_length:
+          sess.run(model.cur_length_incr_op)
+          for t in tasks:
+            if data.train_set[t]: do_incr = False
+        if pull < 1:
+          sess.run(model.pull_incr_op)
+        else:
+          data.print_out("  Averaging parameters.")
+          sess.run([model.avg_op, model.lr_decay_op])
+      else:
+        acc_perp = data.safe_exp(acc_loss)
+        if acc_perp > max(prev_acc_perp[-3:]):
+          sess.run(model.lr_decay_op)
+        prev_acc_perp.append(acc_perp)
+      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
+      model.saver.save(sess, checkpoint_path,
+                       global_step=model.global_step)
+      # Run evaluation.
+      should_exit = True
+      bound = data.bins[-1] + 1
+      for t in tasks:
+        l = min_length
+        while l < max_length + 12 and l < bound:
+          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+          l += 1
+          while l < bound + 1 and not data.test_set[t][l]:
+            l += 1
+        if sq < 0.5:
+          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
+                             batch_size * 4)
+        if sq > 0.001: should_exit = False
+      if should_exit:
+        if data.forward_max > 4000 and len(tasks) == 1:
+          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
+                     batch_size * 16, 0)
+
+
+def animate(l, test_data, anim_size):
+  """Create animation for the given data (hacky matplotlib use)."""
+  xf = 12
+  fps = 2
+  fig = plt.figure(figsize=(16, 9), facecolor="white")
+  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
+  ax.set_xticks([i * 24-0.5 for i in xrange(4)])
+  ax.set_xticklabels([])
+  ax.set_yticks([i - 0.5 for i in xrange(l+1)])
+  ax.grid(which="major", axis="both", linestyle="-", color="black")
+  text_fields = []
+  text_size = 24*32/l
+  for y in xrange(l):
+    text_fields.append(ax.text(
+        11.25, y + 0.15, "", color="g", ha="center", va="center",
+        bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
+        size=text_size - (4 * 32 / l), animated=True))
+  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
+                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
+                 interpolation="none", animated=True)
+  im.set_zorder(1)
+  def to_symbol(i):
+    if i == 0: return ""
+    if i == 11: return "+"
+    if i == 12: return "*"
+    return str(i-1)
+  def animation_update(frame_no, test_data, xf, im, text_fields):
+    """Update an animation frame."""
+    steps, inpt, out_raw = test_data
+    length = len(steps)
+    batch = frame_no / (fps * (l+4*xf))
+    index = int((frame_no % (fps * (l+4*xf))) / fps)
+    # Cut output after first padding.
+    out = [out_raw[i][batch] for i in xrange(len(text_fields))]
+    if 0 in out:
+      i = out.index(0)
+      out = out[0:i] + [0 for _ in xrange(len(out) - i)]
+    # Show the state after the first frames.
+    if index >= 2*xf:
+      im.set_array(steps[min(length - 1, index - 2*xf)][batch])
+      for i, t in enumerate(text_fields):
+        if index - 2*xf < length:
+          t.set_text("")
+        else:
+          t.set_text(to_symbol(out[i]))
+    else:
+      for i, t in enumerate(text_fields):
+        t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
+      if index < xf:
+        im.set_array(np.zeros_like(steps[0][0]))
+      else:
+        im.set_array(steps[0][batch])
+    return im,
+  animation = anim.FuncAnimation(
+      fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps,
+      interval=500/fps, fargs=(test_data, xf, im, text_fields))
+  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)
+
+
+def evaluate():
+  """Evaluate an existing model."""
+  batch_size = FLAGS.batch_size
+  tasks = FLAGS.task.split("-")
+  with tf.Session() as sess:
+    model, min_length, max_length, _, _ = initialize(sess)
+    bound = data.bins[-1] + 1
+    for t in tasks:
+      l = min_length
+      while l < max_length + 12 and l < bound:
+        _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+        l += 1
+        while l < bound + 1 and not data.test_set[t][l]:
+          l += 1
+      # Animate.
+      anim_size = 2
+      _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
+      animate(l, test_data, anim_size)
+      # More tests.
+      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
+                         batch_size * 4)
+    if sq < 0.01:  # More tests.
+      if data.forward_max > 4000 and len(tasks) == 1:
+        multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
+                   batch_size * 64, 0)
+
+
+def interactive():
+  """Interactively probe an existing model."""
+  with tf.Session() as sess:
+    model, _, _, _, _ = initialize(sess)
+    sys.stdout.write("> ")
+    sys.stdout.flush()
+    inpt = sys.stdin.readline()
+    while inpt:
+      ids = [int(c) for c in inpt.strip()]
+      inpt, target = data.get_batch(len(ids), 1, False, "",
+                                    preset=(ids, [0 for _ in ids]))
+      _, res, _, _ = model.step(sess, inpt, target, False)
+      res = [np.argmax(o, axis=1) for o in res]
+      print " ".join([str(output[0]) for output in res])
+      sys.stdout.write("> ")
+      sys.stdout.flush()
+      inpt = sys.stdin.readline()
+
+
+def main(_):
+  if FLAGS.mode == 0:
+    train()
+  elif FLAGS.mode == 1:
+    evaluate()
+  else:
+    interactive()
+
+if __name__ == "__main__":
+  tf.app.run()