
Update to the Neural GPU.

Lukasz Kaiser 8 years ago
parent
commit
a315e5681d

+ 21 - 12
neural_gpu/README.md

@@ -4,7 +4,6 @@ in [[http://arxiv.org/abs/1511.08228]].
 
 Requirements:
 * TensorFlow (see tensorflow.org for how to install)
-* Matplotlib for Python (sudo apt-get install python-matplotlib)
 
 The model can be trained on the following algorithmic tasks:
 
@@ -26,17 +25,27 @@ The model can be trained on the following algorithmic tasks:
 * `qadd` - Long quaternary addition
 * `search` - Search for symbol key in dictionary
 
-The value range for symbols are defined by the `niclass` and `noclass` flags.
-In particular, the values are in the range `min(--niclass, noclass) - 1`.
-So if you set `--niclass=33` and `--noclass=33` (the default) then `--task=rev`
-will be reversing lists of 32 symbols, and `--task=id` will be identity on a
-list of up to 32 symbols.
+It can also be trained on the WMT English-French translation task:
 
+* `wmt` - WMT English-French translation (data will be downloaded)
 
-To train the model on the reverse task run:
+The value range for symbols is defined by the `vocab_size` flag.
+In particular, symbol values are at most `vocab_size - 1`.
+So if you set `--vocab_size=16` (the default), then `--problem=rev`
+will reverse lists drawn from an alphabet of 15 symbols, and `--problem=id`
+will be the identity on such lists.
+
+
+To train the model on the binary multiplication task, run:
+
+```
+python neural_gpu_trainer.py --problem=bmul
+```
+
+This trains the Extended Neural GPU; to train the original model, run:
 
 ```
-python neural_gpu_trainer.py --task=rev
+python neural_gpu_trainer.py --problem=bmul --beam_size=0
 ```
 
 While training, interim / checkpoint model parameters will be
@@ -47,16 +56,16 @@ with, hit `Ctrl-C` to stop the training process. The latest
 model parameters will be in `/tmp/neural_gpu/neural_gpu.ckpt-<step>`
 and used on any subsequent run.
 
-To test a trained model on how well it decodes run:
+To evaluate how well a trained model decodes, run:
 
 ```
-python neural_gpu_trainer.py --task=rev --mode=1
+python neural_gpu_trainer.py --problem=bmul --mode=1
 ```
 
-To produce an animation of the result run:
+To interact with a trained model (experimental, see the code), run:
 
 ```
-python neural_gpu_trainer.py --task=rev --mode=1 --animate=True
+python neural_gpu_trainer.py --problem=bmul --mode=2
 ```
 
 Maintained by Lukasz Kaiser (lukaszkaiser)

+ 193 - 52
neural_gpu/data_utils.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Convolutional Gated Recurrent Networks for Algorithm Learning."""
+"""Neural GPU -- data generation and batching utilities."""
 
 import math
+import os
 import random
 import sys
 import time
@@ -22,22 +23,28 @@ import time
 import numpy as np
 import tensorflow as tf
 
-from tensorflow.python.platform import gfile
+import program_utils
 
 FLAGS = tf.app.flags.FLAGS
 
-bins = [8, 12, 16, 20, 24, 28, 32, 36, 40, 48, 64, 128]
+bins = [2 + bin_idx_i for bin_idx_i in xrange(256)]
 all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
              "right", "left-shift", "right-shift", "bmul", "mul", "dup",
-             "badd", "qadd", "search"]
-forward_max = 128
+             "badd", "qadd", "search", "progeval", "progsynth"]
 log_filename = ""
+vocab, rev_vocab = None, None
 
 
 def pad(l):
   for b in bins:
     if b >= l: return b
-  return forward_max
+  return bins[-1]
+
+
+def bin_for(l):
+  for i, b in enumerate(bins):
+    if b >= l: return i
+  return len(bins) - 1
 
 
 train_set = {}
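
For reference, a standalone check (plain Python, not part of the commit) of the new bucketing scheme defined by `bins`, `pad` and `bin_for` above: lengths are padded up to the nearest bin, and anything longer than the last bin falls into it.

```
bins = [2 + i for i in range(256)]                    # 2, 3, ..., 257

def pad(l):                                           # smallest bin that fits length l
  return next((b for b in bins if b >= l), bins[-1])

def bin_for(l):                                       # index of that bin
  return next((i for i, b in enumerate(bins) if b >= l), len(bins) - 1)

assert pad(10) == 10 and bin_for(10) == 8
assert pad(1000) == 257 and bin_for(1000) == 255      # overflow uses the last bin
```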
@@ -50,6 +57,35 @@ for some_task in all_tasks:
     test_set[some_task].append([])
 
 
+def read_tmp_file(name):
+  """Read from a file with the given name in our log directory or above."""
+  dirname = os.path.dirname(log_filename)
+  fname = os.path.join(dirname, name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    fname = os.path.join(dirname, "../" + name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    fname = os.path.join(dirname, "../../" + name + ".txt")
+  if not tf.gfile.Exists(fname):
+    print_out("== not found file: " + fname)
+    return None
+  print_out("== found file: " + fname)
+  res = []
+  with tf.gfile.GFile(fname, mode="r") as f:
+    for line in f:
+      res.append(line.strip())
+  return res
+
+
+def write_tmp_file(name, lines):
+  dirname = os.path.dirname(log_filename)
+  fname = os.path.join(dirname, name + ".txt")
+  with tf.gfile.GFile(fname, mode="w") as f:
+    for line in lines:
+      f.write(line + "\n")
+
+
 def add(n1, n2, base=10):
   """Add two numbers represented as lower-endian digit lists."""
   k = max(len(n1), len(n2)) + 1
@@ -130,6 +166,30 @@ def init_data(task, length, nbr_cases, nclass):
     sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
     return [x for p in kv for x in p], [x for p in sorted_kv for x in p]
 
+  def prog_io_pair(prog, max_len, counter=0):
+    try:
+      ilen = np.random.randint(max_len - 3) + 1
+      bound = max(15 - (counter / 20), 1)
+      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
+      inp_toks = [program_utils.prog_rev_vocab[t]
+                  for t in program_utils.tokenize(str(inp)) if t != ","]
+      out = program_utils.evaluate(prog, {"a": inp})
+      out_toks = [program_utils.prog_rev_vocab[t]
+                  for t in program_utils.tokenize(str(out)) if t != ","]
+      if counter > 400:
+        out_toks = []
+      if (out_toks and out_toks[0] == program_utils.prog_rev_vocab["["] and
+          len(out_toks) != len([o for o in out if o == ","]) + 3):
+        raise ValueError("generated list with too long ints")
+      if (out_toks and out_toks[0] != program_utils.prog_rev_vocab["["] and
+          len(out_toks) > 1):
+        raise ValueError("generated one int but tokenized it to many")
+      if len(out_toks) > max_len:
+        raise ValueError("output too long")
+      return (inp_toks, out_toks)
+    except ValueError:
+      return prog_io_pair(prog, max_len, counter+1)
+
   def spec(inp):
     """Return the target given the input for some tasks."""
     if task == "sort":
@@ -164,43 +224,114 @@ def init_data(task, length, nbr_cases, nclass):
   l = length
   cur_time = time.time()
   total_time = 0.0
-  for case in xrange(nbr_cases):
+
+  is_prog = task in ["progeval", "progsynth"]
+  if is_prog:
+    inputs_per_prog = 5
+    program_utils.make_vocab()
+    progs = read_tmp_file("programs_len%d" % (l / 10))
+    if not progs:
+      progs = program_utils.gen(l / 10, 1.2 * nbr_cases / inputs_per_prog)
+      write_tmp_file("programs_len%d" % (l / 10), progs)
+    prog_ios = read_tmp_file("programs_len%d_io" % (l / 10))
+    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
+    if not prog_ios:
+      # Generate program io data.
+      prog_ios = []
+      for pidx, prog in enumerate(progs):
+        if pidx % 500 == 0:
+          print_out("== generating io pairs for program %d" % pidx)
+        if pidx * inputs_per_prog > nbr_cases * 1.2:
+          break
+        ptoks = [program_utils.prog_rev_vocab[t]
+                 for t in program_utils.tokenize(prog)]
+        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
+        plen = len(ptoks)
+        for _ in xrange(inputs_per_prog):
+          if task == "progeval":
+            inp, out = prog_io_pair(prog, plen)
+            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
+          elif task == "progsynth":
+            plen = max(len(ptoks), 8)
+            for _ in xrange(3):
+              inp, out = prog_io_pair(prog, plen / 2)
+              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
+      write_tmp_file("programs_len%d_io" % (l / 10), prog_ios)
+    prog_ios_dict = {}
+    for s in prog_ios:
+      i, o, p = s.split("\t")
+      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
+      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
+      inp = [int(x) for x in i_clean.split()]
+      out = [int(x) for x in o_clean.split()]
+      if inp and out:
+        if p in prog_ios_dict:
+          prog_ios_dict[p].append([inp, out])
+        else:
+          prog_ios_dict[p] = [[inp, out]]
+    # Use prog_ios_dict to create data.
+    progs = []
+    for prog in prog_ios_dict:
+      if len([c for c in prog if c == ";"]) <= (l / 10):
+        progs.append(prog)
+    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
+    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
+    for pidx, prog in enumerate(progs):
+      if pidx * inputs_per_prog > nbr_cases * 1.2:
+        break
+      ptoks = [program_utils.prog_rev_vocab[t]
+               for t in program_utils.tokenize(prog)]
+      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
+      plen = len(ptoks)
+      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
+      for _ in xrange(inputs_per_prog):
+        if task == "progeval":
+          inp, out = prog_ios_dict[prog].pop()
+          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
+        elif task == "progsynth":
+          plen, ilist = max(len(ptoks), 8), [[]]
+          for _ in xrange(3):
+            inp, out = prog_ios_dict[prog].pop()
+            ilist.append(inp + out)
+          dset[task][bin_for(plen)].append([ilist, [ptoks]])
+
+  for case in xrange(0 if is_prog else nbr_cases):
     total_time += time.time() - cur_time
     cur_time = time.time()
     if l > 10000 and case % 100 == 1:
       print_out("  avg gen time %.4f s" % (total_time / float(case)))
     if task in ["add", "badd", "qadd", "bmul", "mul"]:
       i, t = rand_pair(l, task)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
       i, t = rand_pair(l, task)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
     elif task == "dup":
       i, t = rand_dup_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_dup_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "rev2":
       i, t = rand_rev2_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_rev2_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "search":
       i, t = rand_search_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_search_pair(l)
-      test_set[task][len(i)].append([i, t])
+      test_set[task][bin_for(len(i))].append([[i], [t]])
     elif task == "kvsort":
       i, t = rand_kvsort_pair(l)
-      train_set[task][len(i)].append([i, t])
+      train_set[task][bin_for(len(i))].append([[i], [t]])
       i, t = rand_kvsort_pair(l)
-      test_set[task][len(i)].append([i, t])
-    else:
+      test_set[task][bin_for(len(i))].append([[i], [t]])
+    elif task not in ["progeval", "progsynth"]:
       inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
       target = spec(inp)
-      train_set[task][l].append([inp, target])
+      train_set[task][bin_for(l)].append([[inp], [target]])
       inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
       target = spec(inp)
-      test_set[task][l].append([inp, target])
+      test_set[task][bin_for(l)].append([[inp], [target]])
 
 
 def to_symbol(i):
@@ -218,37 +349,31 @@ def to_id(s):
   return int(s) + 1
 
 
-def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
+def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
   """Get a batch of data, training or testing."""
-  inputs = []
-  targets = []
-  length = max_length
-  if preset is None:
-    cur_set = test_set[task]
-    if do_train: cur_set = train_set[task]
-    while not cur_set[length]:
-      length -= 1
-  pad_length = pad(length)
+  inputs, targets = [], []
+  pad_length = bins[bin_id]
   for b in xrange(batch_size):
     if preset is None:
-      elem = random.choice(cur_set[length])
-      if offset is not None and offset + b < len(cur_set[length]):
-        elem = cur_set[length][offset + b]
+      elem = random.choice(data_set[bin_id])
+      if offset is not None and offset + b < len(data_set[bin_id]):
+        elem = data_set[bin_id][offset + b]
     else:
       elem = preset
-    inp, target = elem[0], elem[1]
-    assert len(inp) == length
-    inputs.append(inp + [0 for l in xrange(pad_length - len(inp))])
-    targets.append(target + [0 for l in xrange(pad_length - len(target))])
-  res_input = []
-  res_target = []
-  for l in xrange(pad_length):
-    new_input = np.array([inputs[b][l] for b in xrange(batch_size)],
-                         dtype=np.int32)
-    new_target = np.array([targets[b][l] for b in xrange(batch_size)],
-                          dtype=np.int32)
-    res_input.append(new_input)
-    res_target.append(new_target)
+    inpt, targett, inpl, targetl = elem[0], elem[1], [], []
+    for inp in inpt:
+      inpl.append(inp + [0 for _ in xrange(pad_length - len(inp))])
+    if len(inpl) == 1:
+      for _ in xrange(height - 1):
+        inpl.append([0 for _ in xrange(pad_length)])
+    for target in targett:
+      targetl.append(target + [0 for _ in xrange(pad_length - len(target))])
+    inputs.append(inpl)
+    targets.append(targetl)
+  res_input = np.array(inputs, dtype=np.int32)
+  res_target = np.array(targets, dtype=np.int32)
+  assert list(res_input.shape) == [batch_size, height, pad_length]
+  assert list(res_target.shape) == [batch_size, 1, pad_length]
   return res_input, res_target
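
A minimal NumPy sketch (not part of the commit) of the batch layout that the rewritten `get_batch` asserts: inputs are `[batch, height, padded_length]`, targets `[batch, 1, padded_length]`, with rows right-padded with 0. The sizes below are hypothetical.

```
import numpy as np

batch_size, height, pad_length = 4, 2, 8        # hypothetical sizes
inp_row = [3, 1, 4, 1, 5]                       # one input row, length 5
padded = inp_row + [0] * (pad_length - len(inp_row))

inputs = np.zeros((batch_size, height, pad_length), dtype=np.int32)
targets = np.zeros((batch_size, 1, pad_length), dtype=np.int32)
inputs[0, 0, :] = padded                        # first row of the first example

assert list(inputs.shape) == [batch_size, height, pad_length]
assert list(targets.shape) == [batch_size, 1, pad_length]
```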
 
 
@@ -256,11 +381,11 @@ def print_out(s, newline=True):
   """Print a message out and log it to file."""
   if log_filename:
     try:
-      with gfile.GFile(log_filename, mode="a") as f:
+      with tf.gfile.GFile(log_filename, mode="a") as f:
         f.write(s + ("\n" if newline else ""))
     # pylint: disable=bare-except
     except:
-      sys.stdout.write("Error appending to %s\n" % log_filename)
+      sys.stderr.write("Error appending to %s\n" % log_filename)
   sys.stdout.write(s + ("\n" if newline else ""))
   sys.stdout.flush()
 
@@ -269,21 +394,36 @@ def decode(output):
   return [np.argmax(o, axis=1) for o in output]
 
 
-def accuracy(inpt, output, target, batch_size, nprint):
+def accuracy(inpt_t, output, target_t, batch_size, nprint,
+             beam_out=None, beam_scores=None):
   """Calculate output accuracy given target."""
   assert nprint < batch_size + 1
+  inpt = []
+  for h in xrange(inpt_t.shape[1]):
+    inpt.extend([inpt_t[:, h, l] for l in xrange(inpt_t.shape[2])])
+  target = [target_t[:, 0, l] for l in xrange(target_t.shape[2])]
+  def tok(i):
+    if rev_vocab and i < len(rev_vocab):
+      return rev_vocab[i]
+    return str(i - 1)
   def task_print(inp, output, target):
     stop_bound = 0
     print_len = 0
     while print_len < len(target) and target[print_len] > stop_bound:
       print_len += 1
-    print_out("    i: " + " ".join([str(i - 1) for i in inp if i > 0]))
+    print_out("    i: " + " ".join([tok(i) for i in inp if i > 0]))
     print_out("    o: " +
-              " ".join([str(output[l] - 1) for l in xrange(print_len)]))
+              " ".join([tok(output[l]) for l in xrange(print_len)]))
     print_out("    t: " +
-              " ".join([str(target[l] - 1) for l in xrange(print_len)]))
+              " ".join([tok(target[l]) for l in xrange(print_len)]))
   decoded_target = target
   decoded_output = decode(output)
+  # Use beam output if given and score is high enough.
+  if beam_out is not None:
+    for b in xrange(batch_size):
+      if beam_scores[b] >= 10.0:
+        for l in xrange(min(len(decoded_output), beam_out.shape[2])):
+          decoded_output[l][b] = int(beam_out[b, 0, l])
   total = 0
   errors = 0
   seq = [0 for b in xrange(batch_size)]
@@ -311,6 +451,7 @@ def accuracy(inpt, output, target, batch_size, nprint):
 
 def safe_exp(x):
   perp = 10000
+  x = float(x)
   if x < 100: perp = math.exp(x)
   if perp > 10000: return 10000
   return perp

+ 637 - 217
neural_gpu/neural_gpu.py

@@ -16,26 +16,34 @@
 
 import time
 
+import numpy as np
 import tensorflow as tf
 
-import data_utils
+from tensorflow.python.framework import function
+import data_utils as data
 
+do_jit = False  # Gives more speed but experimental for now.
+jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
 
-def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
+
+def conv_linear(args, kw, kh, nin, nout, rate, do_bias, bias_start, prefix):
   """Convolutional linear map."""
-  assert args is not None
   if not isinstance(args, (list, tuple)):
     args = [args]
   with tf.variable_scope(prefix):
-    k = tf.get_variable("CvK", [kw, kh, nin, nout])
+    with tf.device("/cpu:0"):
+      k = tf.get_variable("CvK", [kw, kh, nin, nout])
     if len(args) == 1:
-      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
+      arg = args[0]
     else:
-      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
+      arg = tf.concat(args, 3)
+    res = tf.nn.convolution(arg, k, dilation_rate=(rate, 1), padding="SAME")
     if not do_bias: return res
-    bias_term = tf.get_variable("CvB", [nout],
-                                initializer=tf.constant_initializer(0.0))
-    return res + bias_term + bias_start
+    with tf.device("/cpu:0"):
+      bias_term = tf.get_variable(
+          "CvB", [nout], initializer=tf.constant_initializer(bias_start))
+    bias_term = tf.reshape(bias_term, [1, 1, 1, nout])
+    return res + bias_term
 
 
 def sigmoid_cutoff(x, cutoff):
@@ -43,7 +51,34 @@ def sigmoid_cutoff(x, cutoff):
   y = tf.sigmoid(x)
   if cutoff < 1.01: return y
   d = (cutoff - 1.0) / 2.0
-  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
+  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d), name="cutoff_min")
+
+
+@function.Defun(tf.float32, noinline=True)
+def sigmoid_cutoff_12(x):
+  """Sigmoid with cutoff 1.2, specialized for speed and memory use."""
+  y = tf.sigmoid(x)
+  return tf.minimum(1.0, tf.maximum(0.0, 1.2 * y - 0.1), name="cutoff_min_12")
+
+
+@function.Defun(tf.float32, noinline=True)
+def sigmoid_hard(x):
+  """Hard sigmoid."""
+  return tf.minimum(1.0, tf.maximum(0.0, 0.25 * x + 0.5))
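
A worked NumPy check (not part of the commit) of the cutoff sigmoid above with `cutoff=1.2`: compared to a plain sigmoid it saturates exactly at 0 and 1, which is what `sigmoid_cutoff_12` specializes for speed.

```
import numpy as np

def sigmoid_cutoff_np(x, cutoff=1.2):
  y = 1.0 / (1.0 + np.exp(-x))
  d = (cutoff - 1.0) / 2.0                  # 0.1 for cutoff 1.2
  return np.minimum(1.0, np.maximum(0.0, cutoff * y - d))

x = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(sigmoid_cutoff_np(x).round(3))        # [0.    0.223 0.5   0.777 1.   ]
```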
+
+
+def place_at14(decided, selected, it):
+  """Place selected at it-th coordinate of decided, dim=1 of 4."""
+  slice1 = decided[:, :it, :, :]
+  slice2 = decided[:, it + 1:, :, :]
+  return tf.concat([slice1, selected, slice2], 1)
+
+
+def place_at13(decided, selected, it):
+  """Place selected at it-th coordinate of decided, dim=1 of 3."""
+  slice1 = decided[:, :it, :]
+  slice2 = decided[:, it + 1:, :]
+  return tf.concat([slice1, selected, slice2], 1)
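
A tiny NumPy analogue (illustrative only) of `place_at13`: the slice-and-concat overwrites the `it`-th coordinate along dimension 1; `place_at14` does the same for 4-D tensors.

```
import numpy as np

def place_at13_np(decided, selected, it):
  # Overwrite the it-th row along dim 1 by slicing around it and concatenating.
  return np.concatenate([decided[:, :it, :], selected, decided[:, it + 1:, :]], 1)

decided = np.zeros((1, 4, 2))
selected = np.ones((1, 1, 2))
print(place_at13_np(decided, selected, 2)[0, :, 0])   # [0. 0. 1. 0.]
```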
 
 
 def tanh_cutoff(x, cutoff):
@@ -54,18 +89,80 @@ def tanh_cutoff(x, cutoff):
   return tf.minimum(1.0, tf.maximum(-1.0, (1.0 + d) * y))
 
 
-def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
+@function.Defun(tf.float32, noinline=True)
+def tanh_hard(x):
+  """Hard tanh."""
+  return tf.minimum(1.0, tf.maximum(0.0, x))
+
+
+def layer_norm(x, nmaps, prefix, epsilon=1e-5):
+  """Layer normalize the 4D tensor x, averaging over the last dimension."""
+  with tf.variable_scope(prefix):
+    scale = tf.get_variable("layer_norm_scale", [nmaps],
+                            initializer=tf.ones_initializer())
+    bias = tf.get_variable("layer_norm_bias", [nmaps],
+                           initializer=tf.zeros_initializer())
+    mean, variance = tf.nn.moments(x, [3], keep_dims=True)
+    norm_x = (x - mean) / tf.sqrt(variance + epsilon)
+    return norm_x * scale + bias
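
A NumPy check (not part of the commit) of the normalization in `layer_norm`, averaging over the last (channel) dimension, with scale 1 and bias 0:

```
import numpy as np

x = np.array([[[[1.0, 2.0, 3.0, 4.0]]]])     # shape [1, 1, 1, nmaps=4]
mean = x.mean(axis=3, keepdims=True)         # 2.5
var = x.var(axis=3, keepdims=True)           # 1.25
norm_x = (x - mean) / np.sqrt(var + 1e-5)
print(norm_x.round(3))                       # [[[[-1.342 -0.447  0.447  1.342]]]]
```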
+
+
+def conv_gru(inpts, mem, kw, kh, nmaps, rate, cutoff, prefix, do_layer_norm,
+             args_len=None):
   """Convolutional GRU."""
   def conv_lin(args, suffix, bias_start):
-    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
-                       prefix + "/" + suffix)
-  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
-  # candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
-  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
-  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
+    total_args_len = args_len or len(args) * nmaps
+    res = conv_linear(args, kw, kh, total_args_len, nmaps, rate, True,
+                      bias_start, prefix + "/" + suffix)
+    if do_layer_norm:
+      return layer_norm(res, nmaps, prefix + "/" + suffix)
+    else:
+      return res
+  if cutoff == 1.2:
+    reset = sigmoid_cutoff_12(conv_lin(inpts + [mem], "r", 1.0))
+    gate = sigmoid_cutoff_12(conv_lin(inpts + [mem], "g", 1.0))
+  elif cutoff > 10:
+    reset = sigmoid_hard(conv_lin(inpts + [mem], "r", 1.0))
+    gate = sigmoid_hard(conv_lin(inpts + [mem], "g", 1.0))
+  else:
+    reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
+    gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
+  if cutoff > 10:
+    candidate = tanh_hard(conv_lin(inpts + [reset * mem], "c", 0.0))
+  else:
+    # candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
+    candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
   return gate * mem + (1 - gate) * candidate
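
The last line of `conv_gru` is the usual GRU interpolation between the old memory and the candidate; a scalar NumPy illustration with hypothetical values (not part of the commit):

```
import numpy as np

mem = np.array([0.2, 0.8])          # previous state
candidate = np.array([1.0, -1.0])   # tanh of the convolution
gate = np.array([0.9, 0.1])         # gate close to 1 keeps the old state

print(gate * mem + (1 - gate) * candidate)   # [ 0.28 -0.82]
```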
 
 
+CHOOSE_K = 256
+
+
+def memory_call(q, l, nmaps, mem_size, vocab_size, num_gpus, update_mem):
+  raise ValueError("Fill for experiments with additional memory structures.")
+
+
+def memory_run(step, nmaps, mem_size, batch_size, vocab_size,
+               global_step, do_training, update_mem, decay_factor, num_gpus,
+               target_emb_weights, output_w, gpu_targets_tn, it):
+  """Run memory."""
+  q = step[:, 0, it, :]
+  mlabels = gpu_targets_tn[:, it, 0]
+  res, mask, mem_loss = memory_call(
+      q, mlabels, nmaps, mem_size, vocab_size, num_gpus, update_mem)
+  res = tf.gather(target_emb_weights, res) * tf.expand_dims(mask[:, 0], 1)
+
+  # Mix gold and original in the first steps, 20% later.
+  gold = tf.nn.dropout(tf.gather(target_emb_weights, mlabels), 0.7)
+  use_gold = 1.0 - tf.cast(global_step, tf.float32) / (1000. * decay_factor)
+  use_gold = tf.maximum(use_gold, 0.2) * do_training
+  mem = tf.cond(tf.less(tf.random_uniform([]), use_gold),
+                lambda: use_gold * gold + (1.0 - use_gold) * res,
+                lambda: res)
+  mem = tf.reshape(mem, [-1, 1, 1, nmaps])
+  return mem, mem_loss, update_mem
+
+
 @tf.RegisterGradient("CustomIdG")
 def _custom_id_grad(_, grads):
   return grads
@@ -86,237 +183,560 @@ def quantize_weights_op(quant_scale, max_value):
   return tf.group(*ops)
 
 
-def relaxed_average(var_name_suffix, rx_step):
-  """Calculate the average of relaxed variables having var_name_suffix."""
-  relaxed_vars = []
-  for l in xrange(rx_step):
-    with tf.variable_scope("RX%d" % l, reuse=True):
-      try:
-        relaxed_vars.append(tf.get_variable(var_name_suffix))
-      except ValueError:
-        pass
-  dsum = tf.add_n(relaxed_vars)
-  avg = dsum / len(relaxed_vars)
-  diff = [v - avg for v in relaxed_vars]
-  davg = tf.add_n([d*d for d in diff])
-  return avg, tf.reduce_sum(davg)
-
-
-def relaxed_distance(rx_step):
-  """Distance between relaxed variables and their average."""
-  res, ops, rx_done = [], [], {}
-  for v in tf.trainable_variables():
-    if v.name[0:2] == "RX":
-      rx_name = v.op.name[v.name.find("/") + 1:]
-      if rx_name not in rx_done:
-        avg, dist_loss = relaxed_average(rx_name, rx_step)
-        res.append(dist_loss)
-        rx_done[rx_name] = avg
-      ops.append(v.assign(rx_done[rx_name]))
-  return tf.add_n(res), tf.group(*ops)
-
-
-def make_dense(targets, noclass):
+def autoenc_quantize(x, nbits, nmaps, do_training, layers=1):
+  """Autoencoder into nbits vectors of bits, using noise and sigmoids."""
+  enc_x = tf.reshape(x, [-1, nmaps])
+  for i in xrange(layers - 1):
+    enc_x = tf.layers.dense(enc_x, nmaps, name="autoenc_%d" % i)
+  enc_x = tf.layers.dense(enc_x, nbits, name="autoenc_%d" % (layers - 1))
+  noise = tf.truncated_normal(tf.shape(enc_x), stddev=2.0)
+  dec_x = sigmoid_cutoff_12(enc_x + noise * do_training)
+  dec_x = tf.reshape(dec_x, [-1, nbits])
+  for i in xrange(layers):
+    dec_x = tf.layers.dense(dec_x, nmaps, name="autodec_%d" % i)
+  return tf.reshape(dec_x, tf.shape(x))
+
+
+def make_dense(targets, noclass, low_param):
   """Move a batch of targets to a dense 1-hot representation."""
-  with tf.device("/cpu:0"):
-    shape = tf.shape(targets)
-    batch_size = shape[0]
-    indices = targets + noclass * tf.range(0, batch_size)
-    length = tf.expand_dims(batch_size * noclass, 0)
-    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
-  return tf.reshape(dense, [-1, noclass])
-
-
-def check_for_zero(sparse):
-  """In a sparse batch of ints, make 1.0 if it's 0 and 0.0 else."""
-  with tf.device("/cpu:0"):
-    shape = tf.shape(sparse)
-    batch_size = shape[0]
-    sparse = tf.minimum(sparse, 1)
-    indices = sparse + 2 * tf.range(0, batch_size)
-    dense = tf.sparse_to_dense(indices, tf.expand_dims(2 * batch_size, 0),
-                               1.0, 0.0)
-    reshaped = tf.reshape(dense, [-1, 2])
-  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
+  low = low_param / float(noclass - 1)
+  high = 1.0 - low * (noclass - 1)
+  targets = tf.cast(targets, tf.int64)
+  return tf.one_hot(targets, depth=noclass, on_value=high, off_value=low)
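
The reworked `make_dense` amounts to label smoothing via `tf.one_hot`; a NumPy sketch (hypothetical sizes, not part of the commit) of the distribution it produces for `low_param=0.1`:

```
import numpy as np

def make_dense_np(targets, noclass, low_param):
  low = low_param / float(noclass - 1)
  high = 1.0 - low * (noclass - 1)            # equals 1.0 - low_param
  out = np.full((len(targets), noclass), low)
  out[np.arange(len(targets)), targets] = high
  return out

print(make_dense_np([2, 0], noclass=4, low_param=0.1).round(3))
# [[0.033 0.033 0.9   0.033]
#  [0.9   0.033 0.033 0.033]]
```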
+
+
+def reorder_beam(beam_size, batch_size, beam_val, output, is_first,
+                 tensors_to_reorder):
+  """Reorder to minimize beam costs."""
+  # beam_val is [batch_size x beam_size]; let b = batch_size * beam_size
+  # decided is len x b x a x b
+  # output is b x out_size; step is b x len x a x b;
+  outputs = tf.split(tf.nn.log_softmax(output), beam_size, 0)
+  all_beam_vals, all_beam_idx = [], []
+  beam_range = 1 if is_first else beam_size
+  for i in xrange(beam_range):
+    top_out, top_out_idx = tf.nn.top_k(outputs[i], k=beam_size)
+    cur_beam_val = beam_val[:, i]
+    top_out = tf.Print(top_out, [top_out, top_out_idx, beam_val, i,
+                                 cur_beam_val], "GREPO", summarize=8)
+    all_beam_vals.append(top_out + tf.expand_dims(cur_beam_val, 1))
+    all_beam_idx.append(top_out_idx)
+  all_beam_idx = tf.reshape(tf.transpose(tf.concat(all_beam_idx, 1), [1, 0]),
+                            [-1])
+  top_beam, top_beam_idx = tf.nn.top_k(tf.concat(all_beam_vals, 1), k=beam_size)
+  top_beam_idx = tf.Print(top_beam_idx, [top_beam, top_beam_idx],
+                          "GREP", summarize=8)
+  reordered = [[] for _ in xrange(len(tensors_to_reorder) + 1)]
+  top_out_idx = []
+  for i in xrange(beam_size):
+    which_idx = top_beam_idx[:, i] * batch_size + tf.range(batch_size)
+    top_out_idx.append(tf.gather(all_beam_idx, which_idx))
+    which_beam = top_beam_idx[:, i] / beam_size  # [batch]
+    which_beam = which_beam * batch_size + tf.range(batch_size)
+    reordered[0].append(tf.gather(output, which_beam))
+    for reorder_i, t in enumerate(tensors_to_reorder):
+      reordered[reorder_i + 1].append(tf.gather(t, which_beam))
+  new_tensors = [tf.concat(t, 0) for t in reordered]
+  top_out_idx = tf.concat(top_out_idx, 0)
+  return (top_beam, new_tensors[0], top_out_idx, new_tensors[1:])
 
 
 class NeuralGPU(object):
   """Neural GPU Model."""
 
-  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
-               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
-               learning_rate, pull, pull_incr, min_length, act_noise=0.0):
+  def __init__(self, nmaps, vec_size, niclass, noclass, dropout,
+               max_grad_norm, cutoff, nconvs, kw, kh, height, mem_size,
+               learning_rate, min_length, num_gpus, num_replicas,
+               grad_noise_scale, sampling_rate, act_noise=0.0, do_rnn=False,
+               atrous=False, beam_size=1, backward=True, do_layer_norm=False,
+               autoenc_decay=1.0):
     # Feeds for parameters and ops to update them.
-    self.global_step = tf.Variable(0, trainable=False)
-    self.cur_length = tf.Variable(min_length, trainable=False)
-    self.cur_length_incr_op = self.cur_length.assign_add(1)
-    self.lr = tf.Variable(float(learning_rate), trainable=False)
-    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
-    self.pull = tf.Variable(float(pull), trainable=False)
-    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
+    self.nmaps = nmaps
+    if backward:
+      self.global_step = tf.Variable(0, trainable=False, name="global_step")
+      self.cur_length = tf.Variable(min_length, trainable=False)
+      self.cur_length_incr_op = self.cur_length.assign_add(1)
+      self.lr = tf.Variable(learning_rate, trainable=False)
+      self.lr_decay_op = self.lr.assign(self.lr * 0.995)
     self.do_training = tf.placeholder(tf.float32, name="do_training")
+    self.update_mem = tf.placeholder(tf.int32, name="update_mem")
     self.noise_param = tf.placeholder(tf.float32, name="noise_param")
 
     # Feeds for inputs, targets, outputs, losses, etc.
-    self.input = []
-    self.target = []
-    for l in xrange(data_utils.forward_max + 1):
-      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
-      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
-    self.outputs = []
-    self.losses = []
-    self.grad_norms = []
-    self.updates = []
+    self.input = tf.placeholder(tf.int32, name="inp")
+    self.target = tf.placeholder(tf.int32, name="tgt")
+    self.prev_step = tf.placeholder(tf.float32, name="prev_step")
+    gpu_input = tf.split(self.input, num_gpus, 0)
+    gpu_target = tf.split(self.target, num_gpus, 0)
+    gpu_prev_step = tf.split(self.prev_step, num_gpus, 0)
+    batch_size = tf.shape(gpu_input[0])[0]
+
+    if backward:
+      adam_lr = 0.005 * self.lr
+      adam = tf.train.AdamOptimizer(adam_lr, epsilon=2e-4)
+
+      def adam_update(grads):
+        return adam.apply_gradients(zip(grads, tf.trainable_variables()),
+                                    global_step=self.global_step,
+                                    name="adam_update")
+
+    # When switching from Adam to SGD we perform reverse-decay.
+    if backward:
+      global_step_float = tf.cast(self.global_step, tf.float32)
+      sampling_decay_exponent = global_step_float / 100000.0
+      sampling_decay = tf.maximum(0.05, tf.pow(0.5, sampling_decay_exponent))
+      self.sampling = sampling_rate * 0.05 / sampling_decay
+    else:
+      self.sampling = tf.constant(0.0)
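
A quick numeric check (plain Python, not part of the commit) of the reverse-decay sampling schedule above; `sampling_rate=0.1` is a hypothetical flag value:

```
def sampling(step, sampling_rate=0.1):
  decay = max(0.05, 0.5 ** (step / 100000.0))
  return sampling_rate * 0.05 / decay

for step in [0, 100000, 300000, 500000]:
  print("step %d -> sampling %.3f" % (step, sampling(step)))
# step 0 -> 0.005, step 100000 -> 0.010, step 300000 -> 0.040,
# step 500000 -> 0.100 (capped once decay bottoms out at 0.05)
```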
+
+    # Cache variables on cpu if needed.
+    if num_replicas > 1 or num_gpus > 1:
+      with tf.device("/cpu:0"):
+        caching_const = tf.constant(0)
+      tf.get_variable_scope().set_caching_device(caching_const.op.device)
+      # partitioner = tf.variable_axis_size_partitioner(1024*256*4)
+      # tf.get_variable_scope().set_partitioner(partitioner)
+
+    def gpu_avg(l):
+      if l[0] is None:
+        for elem in l:
+          assert elem is None
+        return 0.0
+      if len(l) < 2:
+        return l[0]
+      return sum(l) / float(num_gpus)
+
+    self.length_tensor = tf.placeholder(tf.int32, name="length")
 
-    # Computation.
-    inp0_shape = tf.shape(self.input[0])
-    batch_size = inp0_shape[0]
     with tf.device("/cpu:0"):
       emb_weights = tf.get_variable(
           "embedding", [niclass, vec_size],
           initializer=tf.random_uniform_initializer(-1.7, 1.7))
+      if beam_size > 0:
+        target_emb_weights = tf.get_variable(
+            "target_embedding", [noclass, nmaps],
+            initializer=tf.random_uniform_initializer(-1.7, 1.7))
       e0 = tf.scatter_update(emb_weights,
                              tf.constant(0, dtype=tf.int32, shape=[1]),
                              tf.zeros([1, vec_size]))
-
-    adam = tf.train.AdamOptimizer(self.lr, epsilon=1e-4)
-
-    # Main graph creation loop, for every bin in data_utils.
-    self.steps = []
-    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
-      data_utils.print_out("Creating model for bin of length %d." % length)
-      start_time = time.time()
-      if length > data_utils.bins[0]:
+      output_w = tf.get_variable("output_w", [nmaps, noclass], tf.float32)
+
+    def conv_rate(layer):
+      if atrous:
+        return 2**layer
+      return 1
+
+    # pylint: disable=cell-var-from-loop
+    def enc_step(step):
+      """Encoder step."""
+      if autoenc_decay < 1.0:
+        quant_step = autoenc_quantize(step, 16, nmaps, self.do_training)
+        if backward:
+          exp_glob = tf.train.exponential_decay(1.0, self.global_step - 10000,
+                                                1000, autoenc_decay)
+          dec_factor = 1.0 - exp_glob  # * self.do_training
+          dec_factor = tf.cond(tf.less(self.global_step, 10500),
+                               lambda: tf.constant(0.05), lambda: dec_factor)
+        else:
+          dec_factor = 1.0
+        cur = tf.cond(tf.less(tf.random_uniform([]), dec_factor),
+                      lambda: quant_step, lambda: step)
+      else:
+        cur = step
+      if dropout > 0.0001:
+        cur = tf.nn.dropout(cur, keep_prob)
+      if act_noise > 0.00001:
+        cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
+      # Do nconvs-many CGRU steps.
+      if do_jit and tf.get_variable_scope().reuse:
+        with jit_scope():
+          for layer in xrange(nconvs):
+            cur = conv_gru([], cur, kw, kh, nmaps, conv_rate(layer),
+                           cutoff, "ecgru_%d" % layer, do_layer_norm)
+      else:
+        for layer in xrange(nconvs):
+          cur = conv_gru([], cur, kw, kh, nmaps, conv_rate(layer),
+                         cutoff, "ecgru_%d" % layer, do_layer_norm)
+      return cur
+
+    zero_tgt = tf.zeros([batch_size, nmaps, 1])
+    zero_tgt.set_shape([None, nmaps, 1])
+
+    def dec_substep(step, decided):
+      """Decoder sub-step."""
+      cur = step
+      if dropout > 0.0001:
+        cur = tf.nn.dropout(cur, keep_prob)
+      if act_noise > 0.00001:
+        cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
+      # Do nconvs-many CGRU steps.
+      if do_jit and tf.get_variable_scope().reuse:
+        with jit_scope():
+          for layer in xrange(nconvs):
+            cur = conv_gru([decided], cur, kw, kh, nmaps, conv_rate(layer),
+                           cutoff, "dcgru_%d" % layer, do_layer_norm)
+      else:
+        for layer in xrange(nconvs):
+          cur = conv_gru([decided], cur, kw, kh, nmaps, conv_rate(layer),
+                         cutoff, "dcgru_%d" % layer, do_layer_norm)
+      return cur
+    # pylint: enable=cell-var-from-loop
+
+    def dec_step(step, it, it_int, decided, output_ta, tgts,
+                 mloss, nupd_in, out_idx, beam_cost):
+      """Decoder step."""
+      nupd, mem_loss = 0, 0.0
+      if mem_size > 0:
+        it_incr = tf.minimum(it+1, length - 1)
+        mem, mem_loss, nupd = memory_run(
+            step, nmaps, mem_size, batch_size, noclass, self.global_step,
+            self.do_training, self.update_mem, 10, num_gpus,
+            target_emb_weights, output_w, gpu_targets_tn, it_incr)
+      step = dec_substep(step, decided)
+      output_l = tf.expand_dims(tf.expand_dims(step[:, it, 0, :], 1), 1)
+      # Calculate argmax output.
+      output = tf.reshape(output_l, [-1, nmaps])
+      # pylint: disable=cell-var-from-loop
+      output = tf.matmul(output, output_w)
+      if beam_size > 1:
+        beam_cost, output, out, reordered = reorder_beam(
+            beam_size, batch_size, beam_cost, output, it_int == 0,
+            [output_l, out_idx, step, decided])
+        [output_l, out_idx, step, decided] = reordered
+      else:
+        # Scheduled sampling.
+        out = tf.multinomial(tf.stop_gradient(output), 1)
+        out = tf.to_int32(tf.squeeze(out, [1]))
+      out_write = output_ta.write(it, output_l[:batch_size, :, :, :])
+      output = tf.gather(target_emb_weights, out)
+      output = tf.reshape(output, [-1, 1, nmaps])
+      output = tf.concat([output] * height, 1)
+      tgt = tgts[it, :, :, :]
+      selected = tf.cond(tf.less(tf.random_uniform([]), self.sampling),
+                         lambda: output, lambda: tgt)
+      # pylint: enable=cell-var-from-loop
+      dec_write = place_at14(decided, tf.expand_dims(selected, 1), it)
+      out_idx = place_at13(
+          out_idx, tf.reshape(out, [beam_size * batch_size, 1, 1]), it)
+      if mem_size > 0:
+        mem = tf.concat([mem] * height, 2)
+        dec_write = place_at14(dec_write, mem, it_incr)
+      return (step, dec_write, out_write, mloss + mem_loss, nupd_in + nupd,
+              out_idx, beam_cost)
+
+    # Main model construction.
+    gpu_outputs = []
+    gpu_losses = []
+    gpu_grad_norms = []
+    grads_list = []
+    gpu_out_idx = []
+    self.after_enc_step = []
+    for gpu in xrange(num_gpus):  # Multi-GPU towers, average gradients later.
+      length = self.length_tensor
+      length_float = tf.cast(length, tf.float32)
+      if gpu > 0:
         tf.get_variable_scope().reuse_variables()
+      gpu_outputs.append([])
+      gpu_losses.append([])
+      gpu_grad_norms.append([])
+      with tf.name_scope("gpu%d" % gpu), tf.device("/gpu:%d" % gpu):
+        # Main graph creation loop.
+        data.print_out("Creating model.")
+        start_time = time.time()
+
+        # Embed inputs and calculate mask.
+        with tf.device("/cpu:0"):
+          tgt_shape = tf.shape(tf.squeeze(gpu_target[gpu], [1]))
+          weights = tf.where(tf.squeeze(gpu_target[gpu], [1]) > 0,
+                             tf.ones(tgt_shape), tf.zeros(tgt_shape))
+
+          # Embed inputs and targets.
+          with tf.control_dependencies([e0]):
+            start = tf.gather(emb_weights, gpu_input[gpu])  # b x h x l x nmaps
+            gpu_targets_tn = gpu_target[gpu]  # b x 1 x len
+            if beam_size > 0:
+              embedded_targets_tn = tf.gather(target_emb_weights,
+                                              gpu_targets_tn)
+              embedded_targets_tn = tf.transpose(
+                  embedded_targets_tn, [2, 0, 1, 3])  # len x b x 1 x nmaps
+              embedded_targets_tn = tf.concat([embedded_targets_tn] * height, 2)
+
+        # First image comes from start by applying convolution and adding 0s.
+        start = tf.transpose(start, [0, 2, 1, 3])  # Now b x len x h x vec_s
+        first = conv_linear(start, 1, 1, vec_size, nmaps, 1, True, 0.0, "input")
+        first = layer_norm(first, nmaps, "input")
+
+        # Computation steps.
+        keep_prob = dropout * 3.0 / tf.sqrt(length_float)
+        keep_prob = 1.0 - self.do_training * keep_prob
+        act_noise_scale = act_noise * self.do_training
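
A numeric check (illustrative only; `dropout=0.1` is a hypothetical flag value) of the length-dependent dropout just above: longer sequences keep more units, and evaluation disables dropout entirely.

```
import math

def keep_prob(dropout, length, do_training=1.0):
  return 1.0 - do_training * (dropout * 3.0 / math.sqrt(length))

print(round(keep_prob(0.1, 16), 3))    # 0.925
print(round(keep_prob(0.1, 256), 3))   # 0.981
print(keep_prob(0.1, 256, 0.0))        # 1.0 at eval time
```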
+
+        # Start with a convolutional gate merging previous step.
+        step = conv_gru([gpu_prev_step[gpu]], first,
+                        kw, kh, nmaps, 1, cutoff, "first", do_layer_norm)
+
+        # This is just for running a baseline RNN seq2seq model.
+        if do_rnn:
+          self.after_enc_step.append(step)  # Not meaningful here, but needed.
+          lstm_cell = tf.contrib.rnn.BasicLSTMCell(height * nmaps)
+          cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * nconvs)
+          with tf.variable_scope("encoder"):
+            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
+                cell, tf.reshape(step, [batch_size, length, height * nmaps]),
+                dtype=tf.float32, time_major=False)
+
+          # Attention.
+          attn = tf.layers.dense(
+              encoder_outputs, height * nmaps, name="attn1")
+
+          # pylint: disable=cell-var-from-loop
+          @function.Defun(noinline=True)
+          def attention_query(query, attn_v):
+            vecs = tf.tanh(attn + tf.expand_dims(query, 1))
+            mask = tf.reduce_sum(vecs * tf.reshape(attn_v, [1, 1, -1]), 2)
+            mask = tf.nn.softmax(mask)
+            return tf.reduce_sum(encoder_outputs * tf.expand_dims(mask, 2), 1)
+
+          with tf.variable_scope("decoder"):
+            def decoder_loop_fn((state, prev_cell_out, _), (cell_inp, cur_tgt)):
+              """Decoder loop function."""
+              attn_q = tf.layers.dense(prev_cell_out, height * nmaps,
+                                       name="attn_query")
+              attn_res = attention_query(attn_q, tf.get_variable(
+                  "attn_v", [height * nmaps],
+                  initializer=tf.random_uniform_initializer(-0.1, 0.1)))
+              concatenated = tf.reshape(tf.concat([cell_inp, attn_res], 1),
+                                        [batch_size, 2 * height * nmaps])
+              cell_inp = tf.layers.dense(
+                  concatenated, height * nmaps, name="attn_merge")
+              output, new_state = cell(cell_inp, state)
+
+              mem_loss = 0.0
+              if mem_size > 0:
+                res, mask, mem_loss = memory_call(
+                    output, cur_tgt, height * nmaps, mem_size, noclass,
+                    num_gpus, self.update_mem)
+                res = tf.gather(target_emb_weights, res)
+                res *= tf.expand_dims(mask[:, 0], 1)
+                output = tf.layers.dense(
+                    tf.concat([output, res], 1), height * nmaps, name="rnnmem")
+
+              return new_state, output, mem_loss
+            # pylint: enable=cell-var-from-loop
+            gpu_targets = tf.squeeze(gpu_target[gpu], [1])  # b x len
+            gpu_tgt_trans = tf.transpose(gpu_targets, [1, 0])
+            dec_zero = tf.zeros([batch_size, 1], dtype=tf.int32)
+            dec_inp = tf.concat([dec_zero, gpu_targets], 1)
+            dec_inp = dec_inp[:, :length]
+            embedded_dec_inp = tf.gather(target_emb_weights, dec_inp)
+            embedded_dec_inp_proj = tf.layers.dense(
+                embedded_dec_inp, height * nmaps, name="dec_proj")
+            embedded_dec_inp_proj = tf.transpose(embedded_dec_inp_proj,
+                                                 [1, 0, 2])
+            init_vals = (encoder_state,
+                         tf.zeros([batch_size, height * nmaps]), 0.0)
+            _, dec_outputs, mem_losses = tf.scan(
+                decoder_loop_fn, (embedded_dec_inp_proj, gpu_tgt_trans),
+                initializer=init_vals)
+          mem_loss = tf.reduce_mean(mem_losses)
+          outputs = tf.layers.dense(dec_outputs, nmaps, name="out_proj")
+          # Final convolution to get logits, list outputs.
+          outputs = tf.matmul(tf.reshape(outputs, [-1, nmaps]), output_w)
+          outputs = tf.reshape(outputs, [length, batch_size, noclass])
+          gpu_out_idx.append(tf.argmax(outputs, 2))
+        else:  # Here we go with the Neural GPU.
+          # Encoder.
+          enc_length = length
+          step = enc_step(step)  # First step hard-coded.
+          # pylint: disable=cell-var-from-loop
+          i = tf.constant(1)
+          c = lambda i, _s: tf.less(i, enc_length)
+          def enc_step_lambda(i, step):
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+              new_step = enc_step(step)
+            return (i + 1, new_step)
+          _, step = tf.while_loop(
+              c, enc_step_lambda, [i, step],
+              parallel_iterations=1, swap_memory=True)
+          # pylint: enable=cell-var-from-loop
+
+          self.after_enc_step.append(step)
+
+          # Decoder.
+          if beam_size > 0:
+            output_ta = tf.TensorArray(
+                dtype=tf.float32, size=length, dynamic_size=False,
+                infer_shape=False, name="outputs")
+            out_idx = tf.zeros([beam_size * batch_size, length, 1],
+                               dtype=tf.int32)
+            decided_t = tf.zeros([beam_size * batch_size, length,
+                                  height, vec_size])
+
+            # Prepare for beam search.
+            tgts = tf.concat([embedded_targets_tn] * beam_size, 1)
+            beam_cost = tf.zeros([batch_size, beam_size])
+            step = tf.concat([step] * beam_size, 0)
+            # First step hard-coded.
+            step, decided_t, output_ta, mem_loss, nupd, oi, bc = dec_step(
+                step, 0, 0, decided_t, output_ta, tgts, 0.0, 0, out_idx,
+                beam_cost)
+            tf.get_variable_scope().reuse_variables()
+            # pylint: disable=cell-var-from-loop
+            def step_lambda(i, step, dec_t, out_ta, ml, nu, oi, bc):
+              with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                s, d, t, nml, nu, oi, bc = dec_step(
+                    step, i, 1, dec_t, out_ta, tgts, ml, nu, oi, bc)
+              return (i + 1, s, d, t, nml, nu, oi, bc)
+            i = tf.constant(1)
+            c = lambda i, _s, _d, _o, _ml, _nu, _oi, _bc: tf.less(i, length)
+            _, step, _, output_ta, mem_loss, nupd, out_idx, _ = tf.while_loop(
+                c, step_lambda,
+                [i, step, decided_t, output_ta, mem_loss, nupd, oi, bc],
+                parallel_iterations=1, swap_memory=True)
+            # pylint: enable=cell-var-from-loop
+            gpu_out_idx.append(tf.squeeze(out_idx, [2]))
+            outputs = output_ta.stack()
+            outputs = tf.squeeze(outputs, [2, 3])  # Now l x b x nmaps
+          else:
+            # If beam_size is 0 or less, we don't have a decoder.
+            mem_loss = 0.0
+            outputs = tf.transpose(step[:, :, 1, :], [1, 0, 2])
+            gpu_out_idx.append(tf.argmax(outputs, 2))
+
+          # Final convolution to get logits, list outputs.
+          outputs = tf.matmul(tf.reshape(outputs, [-1, nmaps]), output_w)
+          outputs = tf.reshape(outputs, [length, batch_size, noclass])
+        gpu_outputs[gpu] = tf.nn.softmax(outputs)
+
+        # Calculate cross-entropy loss and normalize it.
+        targets_soft = make_dense(tf.squeeze(gpu_target[gpu], [1]),
+                                  noclass, 0.1)
+        targets_soft = tf.reshape(targets_soft, [-1, noclass])
+        targets_hard = make_dense(tf.squeeze(gpu_target[gpu], [1]),
+                                  noclass, 0.0)
+        targets_hard = tf.reshape(targets_hard, [-1, noclass])
+        output = tf.transpose(outputs, [1, 0, 2])
+        xent_soft = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
+            logits=tf.reshape(output, [-1, noclass]), labels=targets_soft),
+                               [batch_size, length])
+        xent_hard = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
+            logits=tf.reshape(output, [-1, noclass]), labels=targets_hard),
+                               [batch_size, length])
+        low, high = 0.1 / float(noclass - 1), 0.9
+        const = high * tf.log(high) + float(noclass - 1) * low * tf.log(low)
+        weight_sum = tf.reduce_sum(weights) + 1e-20
+        true_perp = tf.reduce_sum(xent_hard * weights) / weight_sum
+        soft_loss = tf.reduce_sum(xent_soft * weights) / weight_sum
+        perp_loss = soft_loss + const
+        # Final loss: cross-entropy plus the extra memory loss.
+        mem_loss = 0.5 * tf.reduce_mean(mem_loss) / length_float
+        total_loss = perp_loss + mem_loss
+        gpu_losses[gpu].append(true_perp)
+
+        # Gradients.
+        if backward:
+          data.print_out("Creating backward pass for the model.")
+          grads = tf.gradients(
+              total_loss, tf.trainable_variables(),
+              colocate_gradients_with_ops=True)
+          for g_i, g in enumerate(grads):
+            if isinstance(g, tf.IndexedSlices):
+              grads[g_i] = tf.convert_to_tensor(g)
+          grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
+          gpu_grad_norms[gpu].append(norm)
+          for g in grads:
+            if grad_noise_scale > 0.001:
+              g += tf.truncated_normal(tf.shape(g)) * self.noise_param
+          grads_list.append(grads)
+        else:
+          gpu_grad_norms[gpu].append(0.0)
+        data.print_out("Created model for gpu %d in %.2f s."
+                       % (gpu, time.time() - start_time))
 
-      # Embed inputs and calculate mask.
-      with tf.device("/cpu:0"):
-        with tf.control_dependencies([e0]):
-          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
-                      for l in xrange(length)]
-        # Mask to 0-out padding space in each step.
-        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
-        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
-        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
-        mask = [tf.reshape(m, [-1, 1]) for m in mask]
-        # Use a shifted mask for step scaling and concatenated for weights.
-        shifted_mask = mask + [tf.zeros_like(mask[0])]
-        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
-                  for i in xrange(length)]
-        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
-        mask = tf.concat(1, mask[0:length])  # batch x length
-        weights = mask
-        # Add a height dimension to mask to use later for masking.
-        mask = tf.reshape(mask, [-1, length, 1, 1])
-        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
-            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)
-
-      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
-      start = [tf.tanh(embedded[l]) for l in xrange(length)]
-      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
-      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])
-
-      # First image comes from start by applying one convolution and adding 0s.
-      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
-      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
-                                  dtype=tf.float32) for _ in xrange(height - 1)]
-      first = tf.concat(2, first)
-
-      # Computation steps.
-      keep_prob = 1.0 - self.do_training * (dropout * 8.0 / float(length))
-      step = [tf.nn.dropout(first, keep_prob) * mask]
-      act_noise_scale = act_noise * self.do_training * self.pull
-      outputs = []
-      for it in xrange(length):
-        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
-          if it >= rx_step:
-            vs.reuse_variables()
-          cur = step[it]
-          # Do nconvs-many CGRU steps.
-          for layer in xrange(nconvs):
-            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
-            cur *= mask
-          outputs.append(tf.slice(cur, [0, 0, 0, 0], [-1, -1, 1, -1]))
-          cur = tf.nn.dropout(cur, keep_prob)
-          if act_noise > 0.00001:
-            cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
-          step.append(cur * mask)
-
-      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
-                         for s in step])
-      # Output is the n-th step output; n = current length, as in scales.
-      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
-      # Final convolution to get logits, list outputs.
-      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
-      output = tf.reshape(output, [-1, length, noclass])
-      external_output = [tf.reshape(o, [-1, noclass])
-                         for o in list(tf.split(1, length, output))]
-      external_output = [tf.nn.softmax(o) for o in external_output]
-      self.outputs.append(external_output)
-
-      # Calculate cross-entropy loss and normalize it.
-      targets = tf.concat(1, [make_dense(self.target[l], noclass)
-                              for l in xrange(length)])
-      targets = tf.reshape(targets, [-1, noclass])
-      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
-          tf.reshape(output, [-1, noclass]), targets), [-1, length])
-      perp_loss = tf.reduce_sum(xent * weights)
-      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
-      perp_loss /= length
-
-      # Final loss: cross-entropy + shared parameter relaxation part.
-      relax_dist, self.avg_op = relaxed_distance(rx_step)
-      total_loss = perp_loss + relax_dist * self.pull
-      self.losses.append(perp_loss)
-
-      # Gradients and Adam update operation.
-      if length == data_utils.bins[0] or (mode == 0 and
-                                          length < data_utils.bins[-1] + 1):
-        data_utils.print_out("Creating backward for bin of length %d." % length)
-        params = tf.trainable_variables()
-        grads = tf.gradients(total_loss, params)
-        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
-        self.grad_norms.append(norm)
-        for grad in grads:
-          if isinstance(grad, tf.Tensor):
-            grad += tf.truncated_normal(tf.shape(grad)) * self.noise_param
-        update = adam.apply_gradients(zip(grads, params),
-                                      global_step=self.global_step)
-        self.updates.append(update)
-      data_utils.print_out("Created model for bin of length %d in"
-                           " %.2f s." % (length, time.time() - start_time))
-    self.saver = tf.train.Saver(tf.all_variables())
-
-  def step(self, sess, inp, target, do_backward, noise_param=None,
-           get_steps=False):
+    self.updates = []
+    self.after_enc_step = tf.concat(self.after_enc_step, 0)  # Concat GPUs.
+    if backward:
+      tf.get_variable_scope()._reuse = False
+      tf.get_variable_scope().set_caching_device(None)
+      grads = [gpu_avg([grads_list[g][i] for g in xrange(num_gpus)])
+               for i in xrange(len(grads_list[0]))]
+      update = adam_update(grads)
+      self.updates.append(update)
+    else:
+      self.updates.append(tf.no_op())
+
+    self.losses = [gpu_avg([gpu_losses[g][i] for g in xrange(num_gpus)])
+                   for i in xrange(len(gpu_losses[0]))]
+    self.out_idx = tf.concat(gpu_out_idx, 0)
+    self.grad_norms = [gpu_avg([gpu_grad_norms[g][i] for g in xrange(num_gpus)])
+                       for i in xrange(len(gpu_grad_norms[0]))]
+    self.outputs = [tf.concat([gpu_outputs[g] for g in xrange(num_gpus)], 1)]
+    self.quantize_op = quantize_weights_op(512, 8)
+    if backward:
+      self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
+
+  def step(self, sess, inp, target, do_backward_in, noise_param=None,
+           beam_size=2, eos_id=2, eos_cost=0.0, update_mem=None, state=None):
     """Run a step of the network."""
-    assert len(inp) == len(target)
-    length = len(target)
+    batch_size, height, length = inp.shape[0], inp.shape[1], inp.shape[2]
+    do_backward = do_backward_in
+    train_mode = True
+    if do_backward_in is None:
+      do_backward = False
+      train_mode = False
+    if update_mem is None:
+      update_mem = do_backward
     feed_in = {}
+    # print "    feeding sequences of length %d" % length
+    if state is None:
+      state = np.zeros([batch_size, length, height, self.nmaps])
+    feed_in[self.prev_step.name] = state
+    feed_in[self.length_tensor.name] = length
     feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
     feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
+    feed_in[self.update_mem.name] = 1 if update_mem else 0
+    if do_backward_in is False:
+      feed_in[self.sampling.name] = 0.0
+    index = 0  # We're dynamic now.
     feed_out = []
-    index = len(data_utils.bins)
-    if length < data_utils.bins[-1] + 1:
-      index = data_utils.bins.index(length)
     if do_backward:
       feed_out.append(self.updates[index])
       feed_out.append(self.grad_norms[index])
-    feed_out.append(self.losses[index])
-    for l in xrange(length):
-      feed_in[self.input[l].name] = inp[l]
-    for l in xrange(length):
-      feed_in[self.target[l].name] = target[l]
-      feed_out.append(self.outputs[index][l])
-    if get_steps:
-      for l in xrange(length+1):
-        feed_out.append(self.steps[index][l])
-    res = sess.run(feed_out, feed_in)
+    if train_mode:
+      feed_out.append(self.losses[index])
+    feed_in[self.input.name] = inp
+    feed_in[self.target.name] = target
+    feed_out.append(self.outputs[index])
+    if train_mode:
+      # Make a full-sequence training step with one call to session.run.
+      res = sess.run([self.after_enc_step] + feed_out, feed_in)
+      after_enc_state, res = res[0], res[1:]
+    else:
+      # Make a full-sequence decoding step with one call to session.run.
+      feed_in[self.sampling.name] = 1.1  # Sample every time.
+      res = sess.run([self.after_enc_step, self.out_idx] + feed_out, feed_in)
+      after_enc_state, out_idx = res[0], res[1]
+      res = [res[2][l] for l in xrange(length)]
+      outputs = [out_idx[:, i] for i in xrange(length)]
+      cost = [0.0 for _ in xrange(beam_size * batch_size)]
+      seen_eos = [0 for _ in xrange(beam_size * batch_size)]
+      for idx, logit in enumerate(res):
+        best = outputs[idx]
+        for b in xrange(batch_size):
+          if seen_eos[b] > 1:
+            cost[b] -= eos_cost
+          else:
+            cost[b] += np.log(logit[b][best[b]])
+          if best[b] in [eos_id]:
+            seen_eos[b] += 1
+      res = [[-c for c in cost]] + outputs
+    # Collect and output results.
     offset = 0
     norm = None
     if do_backward:
       offset = 2
       norm = res[1]
-    outputs = res[offset + 1:offset + 1 + length]
-    steps = res[offset + 1 + length:] if get_steps else None
-    return res[offset], outputs, norm, steps
+    if train_mode:
+      outputs = res[offset + 1]
+      outputs = [outputs[l] for l in xrange(length)]
+    return res[offset], outputs, norm, after_enc_state
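+
+  # Illustrative usage sketch (assumes a constructed model object and an
+  # active TensorFlow session; input/target shapes as described above):
+  #   loss, outputs, norm, state = model.step(sess, inp, target,
+  #                                           do_backward_in=True)
+  #   costs, decoded, _, state = model.step(sess, inp, target,
+  #                                         do_backward_in=None)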

Diff file is too large
+ 889 - 326
neural_gpu/neural_gpu_trainer.py


+ 440 - 0
neural_gpu/program_utils.py

@@ -0,0 +1,440 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for generating program synthesis and evaluation data."""
+
+import contextlib
+import sys
+import StringIO
+import random
+import os
+
+class ListType(object):
+  def __init__(self, arg):
+    self.arg = arg
+
+  def __str__(self):
+    return "[" + str(self.arg) + "]"
+
+  def __eq__(self, other):
+    if not isinstance(other, ListType):
+      return False
+    return self.arg == other.arg
+  
+  def __hash__(self):
+    return hash(self.arg)
+
+class VarType(object):
+  def __init__(self, arg):
+    self.arg = arg
+
+  def __str__(self):
+    return str(self.arg)
+
+  def __eq__(self, other):
+    if not isinstance(other, VarType):
+      return False
+    return self.arg == other.arg
+
+  def __hash__(self):
+    return hash(self.arg)
+
+class FunctionType(object):
+  def __init__(self, args):
+    self.args = args
+
+  def __str__(self):
+    return str(self.args[0]) + " -> " + str(self.args[1])
+
+  def __eq__(self, other):
+    if not isinstance(other, FunctionType):
+      return False
+    return self.args == other.args
+
+  def __hash__(self):
+    return hash(tuple(self.args))
+
+
+class Function(object):
+  def __init__(self, name, arg_types, output_type, fn_arg_types = None):
+    self.name = name 
+    self.arg_types = arg_types
+    self.fn_arg_types = fn_arg_types or []
+    self.output_type = output_type
+
+Null = 100
+## Functions
+f_head = Function("c_head", [ListType("Int")], "Int")
+def c_head(xs): return xs[0] if len(xs) > 0 else Null
+
+f_last = Function("c_last", [ListType("Int")], "Int")
+def c_last(xs): return xs[-1] if len(xs) > 0 else Null
+
+f_take = Function("c_take", ["Int", ListType("Int")], ListType("Int"))
+def c_take(n, xs): return xs[:n]
+
+f_drop = Function("c_drop", ["Int", ListType("Int")], ListType("Int"))
+def c_drop(n, xs): return xs[n:]
+
+f_access = Function("c_access", ["Int", ListType("Int")], "Int")
+def c_access(n, xs): return xs[n] if n >= 0 and len(xs) > n else Null
+
+f_max = Function("c_max", [ListType("Int")], "Int")
+def c_max(xs): return max(xs) if len(xs) > 0 else Null
+
+f_min = Function("c_min", [ListType("Int")], "Int")
+def c_min(xs): return min(xs) if len(xs) > 0 else Null
+
+f_reverse = Function("c_reverse", [ListType("Int")], ListType("Int"))
+def c_reverse(xs): return list(reversed(xs))
+
+f_sort = Function("sorted", [ListType("Int")], ListType("Int"))
+# def c_sort(xs): return sorted(xs)
+
+f_sum = Function("sum", [ListType("Int")], "Int")
+# def c_sum(xs): return sum(xs)
+
+
+## Lambdas
+# Int -> Int
+def plus_one(x): return x + 1
+def minus_one(x): return x - 1
+def times_two(x): return x * 2
+def neg(x): return x * (-1)  # Shadowed by the Int -> Bool neg below.
+def div_two(x): return int(x/2)
+def sq(x): return x**2 
+def times_three(x): return x * 3
+def div_three(x): return int(x/3)
+def times_four(x): return x * 4
+def div_four(x): return int(x/4)
+
+# Int -> Bool 
+def pos(x): return x > 0 
+def neg(x): return x < 0
+def even(x): return x%2 == 0
+def odd(x): return x%2 == 1
+
+# Int -> Int -> Int
+def add(x, y): return x + y
+def sub(x, y): return x - y
+def mul(x, y): return x * y
+
+# HOFs
+f_map = Function("map", [ListType("Int")], 
+                        ListType("Int"), 
+                        [FunctionType(["Int", "Int"])])
+f_filter = Function("filter", [ListType("Int")], 
+                              ListType("Int"), 
+                              [FunctionType(["Int", "Bool"])])
+f_count = Function("c_count", [ListType("Int")], 
+                              "Int", 
+                              [FunctionType(["Int", "Bool"])])
+def c_count(f, xs): return len([x for x in xs if f(x)])
+
+f_zipwith = Function("c_zipwith", [ListType("Int"), ListType("Int")], 
+                                  ListType("Int"), 
+                                  [FunctionType(["Int", "Int", "Int"])]) #FIX
+def c_zipwith(f, xs, ys): return [f(x, y) for (x, y) in zip(xs, ys)]
+
+f_scan = Function("c_scan", [ListType("Int")],
+                            ListType("Int"), 
+                            [FunctionType(["Int", "Int", "Int"])])
+def c_scan(f, xs):
+  out = xs
+  for i in range(1, len(xs)):
+    out[i] = f(xs[i], xs[i -1])
+  return out
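+
+# Illustrative examples, using the lambdas defined above:
+#   c_count(even, [1, 2, 3, 4])     == 2
+#   c_zipwith(add, [1, 2], [3, 4])  == [4, 6]
+#   c_scan(add, [1, 2, 3])          == [1, 3, 6]   (note: mutates xs in place)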
+
+@contextlib.contextmanager
+def stdoutIO(stdout=None):
+  old = sys.stdout
+  if stdout is None:
+    stdout = StringIO.StringIO()
+  sys.stdout = stdout
+  yield stdout
+  sys.stdout = old
+
+
+def evaluate(program_str, input_names_to_vals, default="ERROR"):
+  exec_str = []
+  for name, val in input_names_to_vals.iteritems():
+    exec_str += name + " = " + str(val) + "; "
+  exec_str += program_str
+  if type(exec_str) is list:
+    exec_str = "".join(exec_str)
+
+  with stdoutIO() as s:
+    # pylint: disable=bare-except
+    try:
+      exec exec_str + " print(out)"
+      return s.getvalue()[:-1]
+    except:
+      return default
+    # pylint: enable=bare-except
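+
+# Illustrative example, with the helpers defined above in scope:
+#   evaluate("out = c_reverse(a);", {"a": [1, 2, 3]})  ->  "[3, 2, 1]"
+# Any exception raised while executing the program yields `default` instead.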
+
+
+class Statement(object):
+  """Statement class."""
+  
+  def __init__(self, fn, output_var, arg_vars, fn_args=None):
+    self.fn = fn
+    self.output_var = output_var
+    self.arg_vars = arg_vars
+    self.fn_args = fn_args or []
+
+  def __str__(self):
+    return "%s = %s(%s%s%s)"%(self.output_var,
+                              self.fn.name,
+                              ", ".join(self.fn_args),
+                              ", " if self.fn_args else "",
+                              ", ".join(self.arg_vars))
+
+  def substitute(self, env):
+    self.output_var = env.get(self.output_var, self.output_var)
+    self.arg_vars = [env.get(v, v) for v in self.arg_vars]
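+
+  # Illustrative example:
+  #   str(Statement(f_map, "b", ["a"], fn_args=["plus_one"]))
+  #   == "b = map(plus_one, a)"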
+
+
+class ProgramGrower(object):
+  """Grow programs."""
+
+  def __init__(self, functions, types_to_lambdas):
+    self.functions = functions
+    self.types_to_lambdas = types_to_lambdas
+
+  def grow_body(self, new_var_name, dependencies, types_to_vars):
+    """Grow the program body."""
+    choices = []
+    for f in self.functions:
+      if all([a in types_to_vars.keys() for a in f.arg_types]):
+        choices.append(f)
+
+    f = random.choice(choices)
+    args = []
+    for t in f.arg_types:
+      possible_vars = types_to_vars[t]
+      var = random.choice(possible_vars)
+      args.append(var)
+      dependencies.setdefault(new_var_name, []).extend(
+          [var] + (dependencies[var]))
+
+    fn_args = [random.choice(self.types_to_lambdas[t]) for t in f.fn_arg_types]
+    types_to_vars.setdefault(f.output_type, []).append(new_var_name)
+
+    return Statement(f, new_var_name, args, fn_args)
+
+  def grow(self, program_len, input_types):
+    """Grow the program."""
+    var_names = list(reversed(map(chr, range(97, 123))))
+    dependencies = dict()
+    types_to_vars = dict()
+    input_names = []
+    for t in input_types:
+      var = var_names.pop()
+      dependencies[var] = []
+      types_to_vars.setdefault(t, []).append(var)
+      input_names.append(var)
+
+    statements = []
+    for _ in range(program_len - 1):
+      var = var_names.pop()
+      statements.append(self.grow_body(var, dependencies, types_to_vars))
+    statements.append(self.grow_body("out", dependencies, types_to_vars))
+
+    new_var_names = [c for c in map(chr, range(97, 123))
+                     if c not in input_names]
+    new_var_names.reverse()
+    keep_statements = []
+    env = dict()
+    for s in statements:
+      if s.output_var in dependencies["out"]:
+        keep_statements.append(s)
+        env[s.output_var] = new_var_names.pop()
+      if s.output_var == "out":
+        keep_statements.append(s)
+
+    for k in keep_statements:
+      k.substitute(env)
+
+    return Program(input_names, input_types, ";".join(
+        [str(k) for k in keep_statements]))
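+
+  # Illustrative sketch (with the `functions` and `types_to_lambdas` built in
+  # gen() below): grower.grow(2, [ListType("Int")]) returns a random Program
+  # over one list input, e.g. with a body like
+  # "b = c_reverse(a);out = sorted(b)".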
+
+
+class Program(object):
+  """The program class."""
+
+  def __init__(self, input_names, input_types, body):
+    self.input_names = input_names
+    self.input_types = input_types
+    self.body = body
+
+  def evaluate(self, inputs):
+    """Evaluate this program."""
+    if len(inputs) != len(self.input_names):
+      raise AssertionError("inputs and input_names have to "
+                           "have the same length. inp: %s, names: %s" %
+                           (str(inputs), str(self.input_names)))
+    inp_str = ""
+    for (name, inp) in zip(self.input_names, inputs):
+      inp_str += name + " = " + str(inp) + "; "
+
+    with stdoutIO() as s:
+      # pylint: disable=exec-used
+      exec inp_str + self.body + "; print(out)"
+      # pylint: enable=exec-used
+    return s.getvalue()[:-1]
+
+  def flat_str(self):
+    out = ""
+    for s in self.body.split(";"):
+      out += s + ";"
+    return out
+
+  def __str__(self):
+    out = ""
+    for (n, t) in zip(self.input_names, self.input_types):
+      out += n + " = " + str(t) + "\n"
+    for s in self.body.split(";"):
+      out += s + "\n"
+    return out
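+
+  # Illustrative example:
+  #   p = Program(["a"], [ListType("Int")], "out = c_reverse(a)")
+  #   p.evaluate([[1, 2, 3]])  ->  "[3, 2, 1]"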
+
+
+prog_vocab = []
+prog_rev_vocab = {}
+
+
+def tokenize(string, tokens=None):
+  """Tokenize the program string."""
+  if tokens is None:
+    tokens = prog_vocab
+  tokens = sorted(tokens, key=len, reverse=True)
+  out = []
+  string = string.strip()
+  while string:
+    found = False
+    for t in tokens:
+      if string.startswith(t):
+        out.append(t)
+        string = string[len(t):]
+        found = True
+        break
+    if not found:
+      raise ValueError("Couldn't tokenize this: " + string)
+    string = string.strip()
+  return out
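+
+# Illustrative example with an explicit token list (independent of prog_vocab):
+#   tokenize("out = sorted(a)", tokens=["out", "=", "sorted", "(", ")", "a"])
+#   -> ["out", "=", "sorted", "(", "a", ")"]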
+
+
+def clean_up(output, max_val=100):
+  o = eval(str(output))
+  if isinstance(o, bool):
+    return o
+  if isinstance(o, int):
+    if o >= 0:
+      return min(o, max_val)
+    else:
+      return max(o, -1 * max_val)
+  if isinstance(o, list):
+    return [clean_up(l) for l in o]
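+
+# clean_up clamps values (recursing into lists) to [-max_val, max_val], e.g.
+#   clean_up(250) -> 100      clean_up([-300, 5]) -> [-100, 5]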
+
+
+def make_vocab():
+  gen(2, 0)
+
+
+def gen(max_len, how_many):
+  """Generate some programs."""
+  functions = [f_head, f_last, f_take, f_drop, f_access, f_max, f_min,
+               f_reverse, f_sort, f_sum, f_map, f_filter, f_count, f_zipwith,
+               f_scan]
+
+  types_to_lambdas = {
+      FunctionType(["Int", "Int"]): ["plus_one", "minus_one", "times_two",
+                                     "div_two", "sq", "times_three",
+                                     "div_three", "times_four", "div_four"],
+      FunctionType(["Int", "Bool"]): ["pos", "neg", "even", "odd"],
+      FunctionType(["Int", "Int", "Int"]): ["add", "sub", "mul"]
+  }
+
+  tokens = []
+  for f in functions:
+    tokens.append(f.name)
+  for v in types_to_lambdas.values():
+    tokens.extend(v)
+  tokens.extend(["=", ";", ",", "(", ")", "[", "]", "Int", "out"])
+  tokens.extend(map(chr, range(97, 123)))
+
+  io_tokens = map(str, range(-220, 220))
+  if not prog_vocab:
+    prog_vocab.extend(["_PAD", "_EOS"] + tokens + io_tokens)
+    for i, t in enumerate(prog_vocab):
+      prog_rev_vocab[t] = i
+
+  io_tokens += [",", "[", "]", ")", "(", "None"]
+  grower = ProgramGrower(functions=functions,
+                         types_to_lambdas=types_to_lambdas)
+
+  def mk_inp(l):
+    return [random.choice(range(-5, 5)) for _ in range(l)]
+
+  tar = [ListType("Int")]
+  inps = [[mk_inp(3)], [mk_inp(5)], [mk_inp(7)], [mk_inp(15)]]
+
+  save_prefix = None
+  outcomes_to_programs = dict()
+  tried = set()
+  counter = 0
+  choices = [0] if max_len == 0 else range(max_len)
+  while counter < 100 * how_many and len(outcomes_to_programs) < how_many:
+    counter += 1
+    length = random.choice(choices)
+    t = grower.grow(length, tar)
+    while t in tried:
+      length = random.choice(choices)
+      t = grower.grow(length, tar)
+    # print(t.flat_str())
+    tried.add(t)
+    outcomes = [clean_up(t.evaluate(i)) for i in inps]
+    outcome_str = str(zip(inps, outcomes))
+    if outcome_str in outcomes_to_programs:
+      outcomes_to_programs[outcome_str] = min(
+          [t.flat_str(), outcomes_to_programs[outcome_str]],
+          key=lambda x: len(tokenize(x, tokens)))
+    else:
+      outcomes_to_programs[outcome_str] = t.flat_str()
+    if counter % 5000 == 0:
+      print "== proggen: tried: " + str(counter)
+      print "== proggen: kept:  " + str(len(outcomes_to_programs))
+
+    if counter % 250000 == 0 and save_prefix is not None:
+      print "saving..."
+      save_counter = 0
+      progfilename = os.path.join(save_prefix, "prog_" + str(counter) + ".txt")
+      iofilename = os.path.join(save_prefix, "io_" + str(counter) + ".txt")
+      prog_token_filename = os.path.join(save_prefix,
+                                         "prog_tokens_" + str(counter) + ".txt")
+      io_token_filename = os.path.join(save_prefix,
+                                       "io_tokens_" + str(counter) + ".txt")
+      with open(progfilename, "a+") as fp,  \
+           open(iofilename, "a+") as fi, \
+           open(prog_token_filename, "a+") as ftp, \
+           open(io_token_filename, "a+") as fti:
+        for (o, p) in outcomes_to_programs.iteritems():
+          save_counter += 1
+          if save_counter % 500 == 0:
+            print "saving %d of %d" % (save_counter, len(outcomes_to_programs))
+          fp.write(p+"\n")
+          fi.write(o+"\n")
+          ftp.write(str(tokenize(p, tokens))+"\n")
+          fti.write(str(tokenize(o, io_tokens))+"\n")
+
+  return list(outcomes_to_programs.values())
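+
+# Illustrative usage: gen(max_len, how_many) returns at most how_many program
+# strings, deduplicated by their input/output behaviour on the sample inputs;
+# make_vocab() calls gen(2, 0) just to populate prog_vocab / prog_rev_vocab.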

+ 435 - 0
neural_gpu/wmt_utils.py

@@ -0,0 +1,435 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+
+import gzip
+import os
+import re
+import tarfile
+
+from six.moves import urllib
+import tensorflow as tf
+
+# Special vocabulary symbols - we always put them at the start.
+_PAD = b"_PAD"
+_GO = b"_GO"
+_EOS = b"_EOS"
+_UNK = b"_CHAR_UNK"
+_SPACE = b"_SPACE"
+_START_VOCAB = [_PAD, _GO, _EOS, _UNK, _SPACE]
+
+PAD_ID = 0
+GO_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+SPACE_ID = 4
+
+# Regular expressions used to tokenize.
+_CHAR_MARKER = "_CHAR_"
+_CHAR_MARKER_LEN = len(_CHAR_MARKER)
+_SPEC_CHARS = "" + chr(226) + chr(153) + chr(128)
+_PUNCTUATION = "][.,!?\"':;%$#@&*+}{|><=/^~)(_`,0123456789" + _SPEC_CHARS + "-"
+_WORD_SPLIT = re.compile(b"([" + _PUNCTUATION + "])")
+_OLD_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
+_DIGIT_RE = re.compile(br"\d")
+
+# URLs for WMT data.
+_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
+_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
+
+
+def maybe_download(directory, filename, url):
+  """Download filename from url unless it's already in directory."""
+  if not tf.gfile.Exists(directory):
+    print "Creating directory %s" % directory
+    os.mkdir(directory)
+  filepath = os.path.join(directory, filename)
+  if not tf.gfile.Exists(filepath):
+    print "Downloading %s to %s" % (url, filepath)
+    filepath, _ = urllib.request.urlretrieve(url, filepath)
+    statinfo = os.stat(filepath)
+    print "Succesfully downloaded", filename, statinfo.st_size, "bytes"
+  return filepath
+
+
+def gunzip_file(gz_path, new_path):
+  """Unzips from gz_path into new_path."""
+  print "Unpacking %s to %s" % (gz_path, new_path)
+  with gzip.open(gz_path, "rb") as gz_file:
+    with open(new_path, "wb") as new_file:
+      for line in gz_file:
+        new_file.write(line)
+
+
+def get_wmt_enfr_train_set(directory):
+  """Download the WMT en-fr training corpus to directory unless it's there."""
+  train_path = os.path.join(directory, "giga-fren.release2.fixed")
+  if not (tf.gfile.Exists(train_path +".fr") and
+          tf.gfile.Exists(train_path +".en")):
+    corpus_file = maybe_download(directory, "training-giga-fren.tar",
+                                 _WMT_ENFR_TRAIN_URL)
+    print "Extracting tar file %s" % corpus_file
+    with tarfile.open(corpus_file, "r") as corpus_tar:
+      corpus_tar.extractall(directory)
+    gunzip_file(train_path + ".fr.gz", train_path + ".fr")
+    gunzip_file(train_path + ".en.gz", train_path + ".en")
+  return train_path
+
+
+def get_wmt_enfr_dev_set(directory):
+  """Download the WMT en-fr training corpus to directory unless it's there."""
+  dev_name = "newstest2013"
+  dev_path = os.path.join(directory, dev_name)
+  if not (tf.gfile.Exists(dev_path + ".fr") and
+          tf.gfile.Exists(dev_path + ".en")):
+    dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
+    print "Extracting tgz file %s" % dev_file
+    with tarfile.open(dev_file, "r:gz") as dev_tar:
+      fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
+      en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
+      fr_dev_file.name = dev_name + ".fr"  # Extract without "dev/" prefix.
+      en_dev_file.name = dev_name + ".en"
+      dev_tar.extract(fr_dev_file, directory)
+      dev_tar.extract(en_dev_file, directory)
+  return dev_path
+
+
+def is_char(token):
+  if len(token) > _CHAR_MARKER_LEN:
+    if token[:_CHAR_MARKER_LEN] == _CHAR_MARKER:
+      return True
+  return False
+
+
+def basic_detokenizer(tokens):
+  """Reverse the process of the basic tokenizer below."""
+  result = []
+  previous_nospace = True
+  for t in tokens:
+    if is_char(t):
+      result.append(t[_CHAR_MARKER_LEN:])
+      previous_nospace = True
+    elif t == _SPACE:
+      result.append(" ")
+      previous_nospace = True
+    elif previous_nospace:
+      result.append(t)
+      previous_nospace = False
+    else:
+      result.extend([" ", t])
+      previous_nospace = False
+  return "".join(result)
+
+
+old_style = False
+
+
+def basic_tokenizer(sentence):
+  """Very basic tokenizer: split the sentence into a list of tokens."""
+  words = []
+  if old_style:
+    for space_separated_fragment in sentence.strip().split():
+      words.extend(re.split(_OLD_WORD_SPLIT, space_separated_fragment))
+    return [w for w in words if w]
+  for space_separated_fragment in sentence.strip().split():
+    tokens = [t for t in re.split(_WORD_SPLIT, space_separated_fragment) if t]
+    first_is_char = False
+    for i, t in enumerate(tokens):
+      if len(t) == 1 and t in _PUNCTUATION:
+        tokens[i] = _CHAR_MARKER + t
+        if i == 0:
+          first_is_char = True
+    if words and words[-1] != _SPACE and (first_is_char or is_char(words[-1])):
+      tokens = [_SPACE] + tokens
+    spaced_tokens = []
+    for i, tok in enumerate(tokens):
+      spaced_tokens.append(tokens[i])
+      if i < len(tokens) - 1:
+        if tok != _SPACE and not (is_char(tok) or is_char(tokens[i+1])):
+          spaced_tokens.append(_SPACE)
+    words.extend(spaced_tokens)
+  return words
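+
+# Illustrative round trip between the two functions above:
+#   basic_tokenizer("Hello, world!")
+#   -> ["Hello", "_CHAR_,", "_SPACE", "world", "_CHAR_!"]
+#   basic_detokenizer(["Hello", "_CHAR_,", "_SPACE", "world", "_CHAR_!"])
+#   -> "Hello, world!"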
+
+
+def space_tokenizer(sentence):
+  return sentence.strip().split()
+
+
+def is_pos_tag(token):
+  """Check if token is a part-of-speech tag."""
+  return(token in ["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR",
+                   "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT",
+                   "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO",
+                   "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP",
+                   "WP$", "WRB", ".", ",", ":", ")", "-LRB-", "(", "-RRB-",
+                   "HYPH", "$", "``", "''", "ADD", "AFX", "QTR", "BES", "-DFL-",
+                   "GW", "HVS", "NFP"])
+
+
+def parse_constraints(inpt, res):
+  ntags = len(res)
+  nwords = len(inpt)
+  npostags = len([x for x in res if is_pos_tag(x)])
+  nclose = len([x for x in res if x[0] == "/"])
+  nopen = ntags - nclose - npostags
+  return (abs(npostags - nwords), abs(nclose - nopen))
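+
+# Illustrative example: parse_constraints(["the", "dog"], ["DT", "NN"])
+# returns (0, 0): the POS-tag count matches the word count and the
+# opening/closing brackets balance.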
+
+
+def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
+                      tokenizer=None, normalize_digits=False):
+  """Create vocabulary file (if it does not exist yet) from data file.
+
+  Data file is assumed to contain one sentence per line. Each sentence is
+  tokenized and digits are normalized (if normalize_digits is set).
+  Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
+  We write it to vocabulary_path in a one-token-per-line format, so that later
+  the token in the first line gets id=0, the one in the second line gets id=1,
+  and so on.
+
+  Args:
+    vocabulary_path: path where the vocabulary will be created.
+    data_path: data file that will be used to create vocabulary.
+    max_vocabulary_size: limit on the size of the created vocabulary.
+    tokenizer: a function to use to tokenize each data sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+  """
+  if not tf.gfile.Exists(vocabulary_path):
+    print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path)
+    vocab, chars = {}, {}
+    for c in _PUNCTUATION:
+      chars[c] = 1
+
+    # Read French file.
+    with tf.gfile.GFile(data_path + ".fr", mode="rb") as f:
+      counter = 0
+      for line_in in f:
+        line = " ".join(line_in.split())
+        counter += 1
+        if counter % 100000 == 0:
+          print "  processing fr line %d" % counter
+        for c in line:
+          if c in chars:
+            chars[c] += 1
+          else:
+            chars[c] = 1
+        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
+        tokens = [t for t in tokens if not is_char(t) and t != _SPACE]
+        for w in tokens:
+          word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
+          if word in vocab:
+            vocab[word] += 1000000000  # We want target words first.
+          else:
+            vocab[word] = 1000000000
+
+    # Read English file.
+    with tf.gfile.GFile(data_path + ".en", mode="rb") as f:
+      counter = 0
+      for line_in in f:
+        line = " ".join(line_in.split())
+        counter += 1
+        if counter % 100000 == 0:
+          print "  processing en line %d" % counter
+        for c in line:
+          if c in chars:
+            chars[c] += 1
+          else:
+            chars[c] = 1
+        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
+        tokens = [t for t in tokens if not is_char(t) and t != _SPACE]
+        for w in tokens:
+          word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
+          if word in vocab:
+            vocab[word] += 1
+          else:
+            vocab[word] = 1
+
+      sorted_vocab = sorted(vocab, key=vocab.get, reverse=True)
+      sorted_chars = sorted(chars, key=chars.get, reverse=True)
+      sorted_chars = [_CHAR_MARKER + c for c in sorted_chars]
+      vocab_list = _START_VOCAB + sorted_chars + sorted_vocab
+      if tokenizer:
+        vocab_list = _START_VOCAB + sorted_vocab
+      if len(vocab_list) > max_vocabulary_size:
+        vocab_list = vocab_list[:max_vocabulary_size]
+      with tf.gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
+        for w in vocab_list:
+          vocab_file.write(w + b"\n")
+
+
+def initialize_vocabulary(vocabulary_path):
+  """Initialize vocabulary from file.
+
+  We assume the vocabulary is stored one-item-per-line, so a file:
+    dog
+    cat
+  will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
+  also return the reversed-vocabulary ["dog", "cat"].
+
+  Args:
+    vocabulary_path: path to the file containing the vocabulary.
+
+  Returns:
+    a pair: the vocabulary (a dictionary mapping string to integers), and
+    the reversed vocabulary (a list, which reverses the vocabulary mapping).
+
+  Raises:
+    ValueError: if the provided vocabulary_path does not exist.
+  """
+  if tf.gfile.Exists(vocabulary_path):
+    rev_vocab = []
+    with tf.gfile.GFile(vocabulary_path, mode="rb") as f:
+      rev_vocab.extend(f.readlines())
+    rev_vocab = [line.strip() for line in rev_vocab]
+    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
+    return vocab, rev_vocab
+  else:
+    raise ValueError("Vocabulary file %s not found." % vocabulary_path)
+
+
+def sentence_to_token_ids_raw(sentence, vocabulary,
+                              tokenizer=None, normalize_digits=old_style):
+  """Convert a string to list of integers representing token-ids.
+
+  For example, a sentence "I have a dog" may become tokenized into
+  ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
+  "a": 4, "dog": 7"} this function will return [1, 2, 4, 7].
+
+  Args:
+    sentence: the sentence in bytes format to convert to token-ids.
+    vocabulary: a dictionary mapping tokens to integers.
+    tokenizer: a function to use to tokenize each sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+
+  Returns:
+    a list of integers, the token-ids for the sentence.
+  """
+  if tokenizer:
+    words = tokenizer(sentence)
+  else:
+    words = basic_tokenizer(sentence)
+  result = []
+  for w in words:
+    if normalize_digits:
+      w = re.sub(_DIGIT_RE, b"0", w)
+    if w in vocabulary:
+      result.append(vocabulary[w])
+    else:
+      if tokenizer:
+        result.append(UNK_ID)
+      else:
+        result.append(SPACE_ID)
+        for c in w:
+          result.append(vocabulary.get(_CHAR_MARKER + c, UNK_ID))
+        result.append(SPACE_ID)
+  while result and result[0] == SPACE_ID:
+    result = result[1:]
+  while result and result[-1] == SPACE_ID:
+    result = result[:-1]
+  return result
+
+
+def sentence_to_token_ids(sentence, vocabulary,
+                          tokenizer=None, normalize_digits=old_style):
+  """Convert a string to list of integers representing token-ids, tab=0."""
+  tab_parts = sentence.strip().split("\t")
+  toks = [sentence_to_token_ids_raw(t, vocabulary, tokenizer, normalize_digits)
+          for t in tab_parts]
+  res = []
+  for t in toks:
+    res.extend(t)
+    res.append(0)
+  return res[:-1]
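+
+# For example, a tab-separated pair "Hello\tBonjour" is encoded as
+# ids("Hello") + [0] + ids("Bonjour"), with PAD_ID (0) marking the tab.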
+
+
+def data_to_token_ids(data_path, target_path, vocabulary_path,
+                      tokenizer=None, normalize_digits=False):
+  """Tokenize data file and turn into token-ids using given vocabulary file.
+
+  This function loads data line-by-line from data_path, calls the above
+  sentence_to_token_ids, and saves the result to target_path. See comment
+  for sentence_to_token_ids on the details of token-ids format.
+
+  Args:
+    data_path: path to the data file in one-sentence-per-line format.
+    target_path: path where the file with token-ids will be created.
+    vocabulary_path: path to the vocabulary file.
+    tokenizer: a function to use to tokenize each sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+  """
+  if not tf.gfile.Exists(target_path):
+    print "Tokenizing data in %s" % data_path
+    vocab, _ = initialize_vocabulary(vocabulary_path)
+    with tf.gfile.GFile(data_path, mode="rb") as data_file:
+      with tf.gfile.GFile(target_path, mode="w") as tokens_file:
+        counter = 0
+        for line in data_file:
+          counter += 1
+          if counter % 100000 == 0:
+            print "  tokenizing line %d" % counter
+          token_ids = sentence_to_token_ids(line, vocab, tokenizer,
+                                            normalize_digits)
+          tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
+
+
+def prepare_wmt_data(data_dir, vocabulary_size,
+                     tokenizer=None, normalize_digits=False):
+  """Get WMT data into data_dir, create vocabularies and tokenize data.
+
+  Args:
+    data_dir: directory in which the data sets will be stored.
+    vocabulary_size: size of the joint vocabulary to create and use.
+    tokenizer: a function to use to tokenize each data sentence;
+      if None, basic_tokenizer will be used.
+    normalize_digits: Boolean; if true, all digits are replaced by 0s.
+
+  Returns:
+    A tuple of 6 elements:
+      (1) path to the token-ids for English training data-set,
+      (2) path to the token-ids for French training data-set,
+      (3) path to the token-ids for English development data-set,
+      (4) path to the token-ids for French development data-set,
+      (5) path to the vocabulary file,
+      (6) path to the vocabulary file (for compatibility with non-joint vocab).
+  """
+  # Get wmt data to the specified directory.
+  train_path = get_wmt_enfr_train_set(data_dir)
+  dev_path = get_wmt_enfr_dev_set(data_dir)
+
+  # Create vocabularies of the appropriate sizes.
+  vocab_path = os.path.join(data_dir, "vocab%d.txt" % vocabulary_size)
+  create_vocabulary(vocab_path, train_path, vocabulary_size,
+                    tokenizer=tokenizer, normalize_digits=normalize_digits)
+
+  # Create token ids for the training data.
+  fr_train_ids_path = train_path + (".ids%d.fr" % vocabulary_size)
+  en_train_ids_path = train_path + (".ids%d.en" % vocabulary_size)
+  data_to_token_ids(train_path + ".fr", fr_train_ids_path, vocab_path,
+                    tokenizer=tokenizer, normalize_digits=normalize_digits)
+  data_to_token_ids(train_path + ".en", en_train_ids_path, vocab_path,
+                    tokenizer=tokenizer, normalize_digits=normalize_digits)
+
+  # Create token ids for the development data.
+  fr_dev_ids_path = dev_path + (".ids%d.fr" % vocabulary_size)
+  en_dev_ids_path = dev_path + (".ids%d.en" % vocabulary_size)
+  data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, vocab_path,
+                    tokenizer=tokenizer, normalize_digits=normalize_digits)
+  data_to_token_ids(dev_path + ".en", en_dev_ids_path, vocab_path,
+                    tokenizer=tokenizer, normalize_digits=normalize_digits)
+
+  return (en_train_ids_path, fr_train_ids_path,
+          en_dev_ids_path, fr_dev_ids_path,
+          vocab_path, vocab_path)
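+
+# Illustrative usage ("/tmp/wmt_data" is an assumed scratch directory):
+#   (en_train, fr_train, en_dev, fr_dev,
+#    vocab_path, _) = prepare_wmt_data("/tmp/wmt_data", 32000)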