Cleaned-up version of the Neural GPU code (runs outside of google3).

Lukasz Kaiser, 9 years ago
Parent commit 90d6e3b97b
4 changed files with 130 additions and 140 deletions
  1. neural_gpu/README.md (+6, -0)
  2. neural_gpu/data_utils.py (+16, -19)
  3. neural_gpu/neural_gpu.py (+1, -19)
  4. neural_gpu/neural_gpu_trainer.py (+107, -102)

+ 6 - 0
neural_gpu/README.md

@@ -2,4 +2,10 @@
 Code for the Neural GPU model as described
 in [[http://arxiv.org/abs/1511.08228]].
 
+Requirements:
+* TensorFlow (see tensorflow.org for how to install)
+* Matplotlib for Python (sudo apt-get install python-matplotlib)
+
+Run: python neural_gpu_trainer.py --task=rev
+
 Maintained by Lukasz Kaiser (lukaszkaiser)
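
The --task flag above is one of several tf.app.flags defined in neural_gpu_trainer.py; see its diff below. As a usage sketch, a run that overrides a few defaults could look like this (flag names are from this commit, the values are illustrative):

  python neural_gpu_trainer.py --task=rev --lr=0.1 --batch_size=32 --train_dir=/tmp/neural_gpu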

+ 16 - 19
neural_gpu/data_utils.py

@@ -1,19 +1,3 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
-
 """Convolutional Gated Recurrent Networks for Algorithm Learning."""
 """Convolutional Gated Recurrent Networks for Algorithm Learning."""
 
 
 import math
 import math
@@ -21,12 +5,10 @@ import random
 import sys
 import time
 
-import google3
-
 import numpy as np
 import tensorflow as tf
 
-from google3.third_party.tensorflow.python.platform import gfile
+from tensorflow.python.platform import gfile
 
 FLAGS = tf.app.flags.FLAGS
 
@@ -162,6 +144,21 @@ def init_data(task, length, nbr_cases, nclass):
       test_set[task][l].append([inp, target])
 
 
+def to_symbol(i):
+  """Covert ids to text."""
+  if i == 0: return ""
+  if i == 11: return "+"
+  if i == 12: return "*"
+  return str(i-1)
+
+
+def to_id(s):
+  """Covert text to ids."""
+  if s == "+": return 11
+  if s == "*": return 12
+  return int(s) + 1
+
+
 def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
   """Get a batch of data, training or testing."""
   inputs = []
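
The to_symbol/to_id helpers added above fix the model's vocabulary: id 0 is padding and renders as the empty string, a digit d maps to id d+1, and "+" / "*" map to ids 11 / 12. A minimal round-trip sketch in plain Python, assuming only the two functions from this diff:

  # Every printable symbol survives to_id followed by to_symbol.
  for s in ["0", "7", "9", "+", "*"]:
    assert to_symbol(to_id(s)) == s
  assert to_symbol(0) == ""  # id 0 is reserved for padding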

+ 1 - 19
neural_gpu/neural_gpu.py

@@ -1,28 +1,10 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
-
 """The Neural GPU Model."""
 """The Neural GPU Model."""
 
 
 import time
 import time
 
 
-import google3
-
 import tensorflow as tf
 
-from google3.experimental.users.lukaszkaiser.neural_gpu import data_utils
+import data_utils
 
 
 def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):

+ 107 - 102
neural_gpu/neural_gpu_trainer.py

@@ -1,19 +1,3 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
-
 """Neural GPU for Learning Algorithms."""
 """Neural GPU for Learning Algorithms."""
 
 
 import math
 import math
@@ -22,16 +6,15 @@ import random
 import sys
 import time
 
-import google3
-
 import matplotlib.animation as anim
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf
 
-from google3.third_party.tensorflow.python.platform import gfile
-import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
-import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu
+from tensorflow.python.platform import gfile
+
+import data_utils as data
+import neural_gpu
 
 tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
 tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
@@ -39,7 +22,7 @@ tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
 tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
 tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
 tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
 tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
 tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
 tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
-tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
+tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
 tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
 tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
 tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
 tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
 tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
 tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
@@ -63,6 +46,7 @@ tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
 tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
 tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
 
 
 FLAGS = tf.app.flags.FLAGS
 FLAGS = tf.app.flags.FLAGS
+EXTRA_EVAL = 12
 
 
 def initialize(sess):
@@ -83,7 +67,7 @@ def initialize(sess):
   min_length = 3
   max_length = min(FLAGS.max_length, data.bins[-1])
   assert max_length + 1 > min_length
-  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
+  while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
     data.bins = data.bins[:-1]
   assert data.bins[0] > FLAGS.rx_step
   nclass = min(FLAGS.niclass, FLAGS.noclass)
@@ -92,7 +76,7 @@ def initialize(sess):
   # Initialize data for each task.
   tasks = FLAGS.task.split("-")
   for t in tasks:
-    for l in xrange(max_length + 11):
+    for l in xrange(max_length + EXTRA_EVAL - 1):
       data.init_data(t, l, data_size, nclass)
     data.init_data(t, data.bins[-2], data_size, nclass)
     data.init_data(t, data.bins[-1], data_size, nclass)
@@ -101,14 +85,14 @@ def initialize(sess):
 
   # Print out parameters.
   curriculum = 0.12
-  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
-         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
-            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
-  fin = "data %d %s" % (FLAGS.train_data_size, fin)
-  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
-         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
-          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
-  data.print_out(tag)
+  msg1 = ("layers %d kw %d h %d kh %d relax %d batch %d noise %.2f task %s"
+          % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
+             FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
+  msg2 = "data %d %s" % (FLAGS.train_data_size, msg1)
+  msg3 = ("cut %.2f pull %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
+          (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
+           curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, msg2))
+  data.print_out(msg3)
 
   # Create checkpoint directory if it does not exist.
   checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
@@ -120,7 +104,7 @@ def initialize(sess):
   # Create model and initialize it.
   tf.get_variable_scope().set_initializer(
       tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
-  model = ngpu.NeuralGPU(
+  model = neural_gpu.NeuralGPU(
       FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
       FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
       FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
@@ -145,131 +129,148 @@ def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
   """Test model on test data of length l using the given session."""
   """Test model on test data of length l using the given session."""
   inpt, target = data.get_batch(l, batch_size, False, task, offset)
   inpt, target = data.get_batch(l, batch_size, False, task, offset)
   _, res, _, steps = model.step(sess, inpt, target, False)
   _, res, _, steps = model.step(sess, inpt, target, False)
-  errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
-  seq = float(seq) / batch_size
+  errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
+  seq_err = float(seq_err) / batch_size
   if total > 0:
     errors = float(errors) / total
   if print_out:
     data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
-                   % (task, l, 100*errors, 100*seq))
-  return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])
+                   % (task, l, 100*errors, 100*seq_err))
+  return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])
 
 
 def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
   """Run multiple tests at lower batch size to save memory."""
-  errors = 0.0
-  seq = 0.0
+  errors, seq_err = 0.0, 0.0
   to_print = nprint
   low_batch = FLAGS.low_batch_size
   low_batch = min(low_batch, batch_size)
   for mstep in xrange(batch_size / low_batch):
     cur_offset = None if offset is None else offset + mstep * low_batch
-    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
-                             cur_offset)
+    err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
+                                 False, cur_offset)
     to_print = max(0, to_print - low_batch)
     errors += err
-    seq += sq
+    seq_err += sq_err
     if FLAGS.mode > 0:
       cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
-      cur_seq = float(low_batch * seq) / ((mstep+1) * low_batch)
+      cur_seq_err = float(low_batch * seq_err) / ((mstep+1) * low_batch)
       data.print_out("    %s multitest current errors %.2f sequence-errors %.2f"
       data.print_out("    %s multitest current errors %.2f sequence-errors %.2f"
-                     % (task, 100*cur_errors, 100*cur_seq))
+                     % (task, 100*cur_errors, 100*cur_seq_err))
   errors = float(low_batch) * float(errors) / batch_size
-  seq = float(low_batch) * float(seq) / batch_size
+  seq_err = float(low_batch) * float(seq_err) / batch_size
   data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
   data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
-                 % (task, l, 100*errors, 100*seq))
-  return errors, seq
+                 % (task, l, 100*errors, 100*seq_err))
+  return errors, seq_err
 
 
 def train():
-  """Main training function."""
+  """Train the model."""
   batch_size = FLAGS.batch_size
   tasks = FLAGS.task.split("-")
   with tf.Session() as sess:
     model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
     max_cur_length = min(min_length + 3, max_length)
     prev_acc_perp = [1000000 for _ in xrange(3)]
-    prev_sq = 1.0
+    prev_seq_err = 1.0
 
+    # Main training loop.
     while True:
       global_step, pull, max_cur_length, learning_rate = sess.run(
           [model.global_step, model.pull, model.cur_length, model.lr])
-      ep = global_step / FLAGS.steps_per_checkpoint
-      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
+      acc_loss, acc_total, acc_errors, acc_seq_err = 0.0, 0, 0, 0
       acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
       for _ in xrange(FLAGS.steps_per_checkpoint):
         global_step += 1
         task = random.choice(tasks)
-        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
-        l = l1
-        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
-          l = np.random.randint(max_cur_length - min_length+1) + min_length
+
+        # Select the length for curriculum learning.
+        l = np.random.randint(max_cur_length - min_length + 1) + min_length
+        # Prefer longer stuff 60% of time.
+        if np.random.randint(100) < 60:
+          l1 = np.random.randint(max_cur_length - min_length+1) + min_length
           l = max(l, l1)
-        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
-          l = np.random.randint(max_length - min_length + 1) + min_length
+        # Mixed curriculum learning: in 25% of cases go to any larger length.
+        if np.random.randint(100) < 25:
+          l1 = np.random.randint(max_length - min_length + 1) + min_length
           l = max(l, l1)
+
+        # Run a step and time it.
         start_time = time.time()
         inp, target = data.get_batch(l, batch_size, True, task)
-        stepp = math.pow(global_step, -0.55)
-        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
+        noise_param = math.sqrt(math.pow(global_step, -0.55) *
+                                (20 * prev_seq_err)) * FLAGS.grad_noise_scale
         loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
         step_time += time.time() - start_time
         acc_grad_norm += float(gnorm)
+
+        # Accumulate statistics only if we did not exceed curriculum length.
         if l < max_cur_length + 1:
           step_count += 1
           acc_loss += loss
-          errors, total, seq = data.accuracy(inp, res, target,
-                                             batch_size, 0)
+          errors, total, seq_err = data.accuracy(inp, res, target,
+                                                 batch_size, 0)
           acc_total += total
           acc_errors += errors
-          acc_seq += seq
+          acc_seq_err += seq_err
+
+      # Normalize and print out accumulated statistics.
       acc_loss /= step_count
       step_time /= FLAGS.steps_per_checkpoint
-      acc_seq = float(acc_seq) / (step_count * batch_size)
-      prev_sq = acc_seq
+      acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
+      prev_seq_err = acc_seq_err
       acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
-      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
-      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
-      msg = ("%s %s gn %.8f"
-             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
-      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
-                     (msg, max_cur_length, data.safe_exp(acc_loss),
-                      100*acc_errors, 100*acc_seq))
-      if curriculum > acc_seq:
-        prev_acc_perp.append(1000000)
+      msg1 = "step %d step-time %.2f" % (global_step, step_time)
+      msg2 = "lr %.8f pull %.3f" % (learning_rate, pull)
+      msg3 = ("%s %s grad-norm %.8f"
+              % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
+      data.print_out("%s len %d ppx %.8f errors %.2f sequence-errors %.2f" %
+                     (msg3, max_cur_length, data.safe_exp(acc_loss),
+                      100*acc_errors, 100*acc_seq_err))
+
+      # If errors are below the curriculum threshold, move curriculum forward.
+      if curriculum > acc_seq_err:
+        # Increase current length (until the next length that has training data).
         do_incr = True
         while do_incr and max_cur_length < max_length:
           sess.run(model.cur_length_incr_op)
           for t in tasks:
             if data.train_set[t]: do_incr = False
+        # Forget last perplexities if we're not yet at the end.
+        if max_cur_length < max_length:
+          prev_acc_perp.append(1000000)
+        # Either increase pull or, if it's large, average parameters.
         if pull < 1:
           sess.run(model.pull_incr_op)
         else:
           data.print_out("  Averaging parameters.")
           sess.run([model.avg_op, model.lr_decay_op])
-      else:
-        acc_perp = data.safe_exp(acc_loss)
-        if acc_perp > max(prev_acc_perp[-3:]):
-          sess.run(model.lr_decay_op)
-        prev_acc_perp.append(acc_perp)
+
+      # Lower learning rate if we're worse than the last 3 checkpoints.
+      acc_perp = data.safe_exp(acc_loss)
+      if acc_perp > max(prev_acc_perp[-3:]):
+        sess.run(model.lr_decay_op)
+      prev_acc_perp.append(acc_perp)
+
+      # Save checkpoint.
       checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
       model.saver.save(sess, checkpoint_path,
                        global_step=model.global_step)
+
       # Run evaluation.
-      should_exit = True
       bound = data.bins[-1] + 1
       for t in tasks:
         l = min_length
-        while l < max_length + 12 and l < bound:
-          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+        while l < max_length + EXTRA_EVAL and l < bound:
+          _, seq_err, _ = single_test(l, model, sess, t,
+                                      FLAGS.nprint, batch_size)
           l += 1
           while l < bound + 1 and not data.test_set[t][l]:
             l += 1
-        if sq < 0.5:
-          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-                             batch_size * 4)
-        if sq > 0.001: should_exit = False
-      if should_exit:
+        if seq_err < 0.5:  # Run larger test if we're good enough.
+          _, seq_err = multi_test(data.forward_max, model, sess, t,
+                                  FLAGS.nprint, batch_size * 4)
+      if seq_err < 0.01:  # Super-large test on 1-task large-forward models.
         if data.forward_max > 4000 and len(tasks) == 1:
           multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                      batch_size * 16, 0)
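
Two parts of the hunk above are easy to misread when reimplementing: the curriculum length sampling and the gradient-noise schedule. A standalone sketch that mirrors the new code (the function names here are illustrative, not from the repo):

  import math
  import numpy as np

  def sample_length(min_length, max_cur_length, max_length):
    # Uniform draw inside the current curriculum window.
    l = np.random.randint(max_cur_length - min_length + 1) + min_length
    # Prefer longer lengths 60% of the time: take the max of two draws.
    if np.random.randint(100) < 60:
      l = max(l, np.random.randint(max_cur_length - min_length + 1) + min_length)
    # Mixed curriculum: in 25% of cases allow any length up to max_length.
    if np.random.randint(100) < 25:
      l = max(l, np.random.randint(max_length - min_length + 1) + min_length)
    return l

  def noise_param(global_step, prev_seq_err, grad_noise_scale):
    # Gradient noise decays as step^-0.55 and grows with the last sequence error.
    return math.sqrt(math.pow(global_step, -0.55) * 20 * prev_seq_err) * grad_noise_scale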
@@ -277,14 +278,17 @@ def train():
 
 def animate(l, test_data, anim_size):
   """Create animation for the given data (hacky matplotlib use)."""
-  xf = 12
-  fps = 2
+  xf = 12  # Extra frames to slow down at start and end.
+  fps = 2  # Frames per step.
+
+  # Make the figure.
   fig = plt.figure(figsize=(16, 9), facecolor="white")
   ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
   ax.set_xticks([i * 24-0.5 for i in xrange(4)])
   ax.set_xticklabels([])
   ax.set_yticks([i - 0.5 for i in xrange(l+1)])
   ax.grid(which="major", axis="both", linestyle="-", color="black")
+  # We need text fields.
   text_fields = []
   text_size = 24*32/l
   for y in xrange(l):
@@ -296,11 +300,8 @@ def animate(l, test_data, anim_size):
                vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                interpolation="none", animated=True)
   im.set_zorder(1)
-  def to_symbol(i):
-    if i == 0: return ""
-    if i == 11: return "+"
-    if i == 12: return "*"
-    return str(i-1)
+
+  # Main animation step.
   def animation_update(frame_no, test_data, xf, im, text_fields):
     """Update an animation frame."""
     steps, inpt, out_raw = test_data
@@ -319,15 +320,17 @@ def animate(l, test_data, anim_size):
         if index - 2*xf < length:
           t.set_text("")
         else:
-          t.set_text(to_symbol(out[i]))
+          t.set_text(data.to_symbol(out[i]))
     else:
       for i, t in enumerate(text_fields):
-        t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
+        t.set_text(data.to_symbol(inpt[i][batch]) if index < xf else "")
       if index < xf:
         im.set_array(np.zeros_like(steps[0][0]))
       else:
         im.set_array(steps[0][batch])
     return im,
+
+  # Create the animation and save to mp4.
   animation = anim.FuncAnimation(
       fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps,
       interval=500/fps, fargs=(test_data, xf, im, text_fields))
@@ -343,8 +346,8 @@ def evaluate():
     bound = data.bins[-1] + 1
     for t in tasks:
       l = min_length
-      while l < max_length + 12 and l < bound:
-        _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+      while l < max_length + EXTRA_EVAL and l < bound:
+        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
         l += 1
         while l < bound + 1 and not data.test_set[t][l]:
           l += 1
@@ -353,9 +356,9 @@ def evaluate():
       _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
       animate(l, test_data, anim_size)
       # More tests.
-      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-                         batch_size * 4)
-    if sq < 0.01:  # More tests.
+      _, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
+                              batch_size * 4)
+    if seq_err < 0.01:  # Super-test if we're very good and in large-test mode.
       if data.forward_max > 4000 and len(tasks) == 1:
         multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                    batch_size * 64, 0)
@@ -365,16 +368,18 @@ def interactive():
   """Interactively probe an existing model."""
   """Interactively probe an existing model."""
   with tf.Session() as sess:
   with tf.Session() as sess:
     model, _, _, _, _ = initialize(sess)
     model, _, _, _, _ = initialize(sess)
+    sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
     sys.stdout.write("> ")
     sys.stdout.write("> ")
     sys.stdout.flush()
     sys.stdout.flush()
     inpt = sys.stdin.readline()
     inpt = sys.stdin.readline()
     while inpt:
     while inpt:
-      ids = [int(c) for c in inpt.strip()]
+      ids = [data.to_id(s) for s in inpt.strip().split()]
       inpt, target = data.get_batch(len(ids), 1, False, "",
                                     preset=(ids, [0 for _ in ids]))
       _, res, _, _ = model.step(sess, inpt, target, False)
       res = [np.argmax(o, axis=1) for o in res]
-      print " ".join([str(output[0]) for output in res])
+      res = [o for o in res[:len(ids)] if o > 0]
+      print "  " + " ".join([data.to_symbol(output[0]) for output in res])
       sys.stdout.write("> ")
       sys.stdout.write("> ")
       sys.stdout.flush()
       sys.stdout.flush()
       inpt = sys.stdin.readline()
       inpt = sys.stdin.readline()
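
The interactive loop above splits each input line on whitespace and maps tokens through data_utils.to_id, then decodes the model's argmax outputs back through to_symbol. A deterministic sketch of the input side alone (no session or model needed):

  ids = [to_id(s) for s in "2 + 3".strip().split()]
  print(ids)  # [3, 11, 4] -- "+" is 11, digits shift up by one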