Преглед изворни кода

Update Neural GPU: allow to quantize activations, add a few tasks.

Lukasz Kaiser пре 9 година
родитељ
комит
2e4f31a174
3 измењених фајлова са 169 додато и 42 уклоњено
  1. 51 7
      neural_gpu/data_utils.py
  2. 48 11
      neural_gpu/neural_gpu.py
  3. 70 24
      neural_gpu/neural_gpu_trainer.py

+ 51 - 7
neural_gpu/data_utils.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Convolutional Gated Recurrent Networks for Algorithm Learning."""
 
 import math
@@ -27,9 +26,10 @@ from tensorflow.python.platform import gfile
 
 FLAGS = tf.app.flags.FLAGS
 
-bins = [8, 16, 32, 64, 128]
-all_tasks = ["sort", "id", "rev", "incr", "left", "right", "left-shift", "add",
-             "right-shift", "bmul", "dup", "badd", "qadd"]
+bins = [8, 12, 16, 20, 24, 28, 32, 36, 40, 48, 64, 128]
+all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
+             "right", "left-shift", "right-shift", "bmul", "mul", "dup",
+             "badd", "qadd", "search"]
 forward_max = 128
 log_filename = ""
 
@@ -82,10 +82,13 @@ def init_data(task, length, nbr_cases, nclass):
     d2 = [np.random.randint(base) for _ in xrange(k)]
     if task in ["add", "badd", "qadd"]:
       res = add(d1, d2, base)
-    elif task in ["bmul"]:
+    elif task in ["mul", "bmul"]:
       d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
       d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
-      res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
+      if task == "bmul":
+        res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
+      else:
+        res = [int(x) for x in list(reversed(str(d1n * d2n)))]
     else:
       sys.exit()
     sep = [12]
@@ -101,6 +104,32 @@ def init_data(task, length, nbr_cases, nclass):
     res = x + x + [0 for _ in xrange(l - 2*k)]
     return inp, res
 
+  def rand_rev2_pair(l):
+    """Random data pair for reverse2 task. Total length should be <= l."""
+    inp = [(np.random.randint(nclass - 1) + 1,
+            np.random.randint(nclass - 1) + 1) for _ in xrange(l/2)]
+    res = [i for i in reversed(inp)]
+    return [x for p in inp for x in p], [x for p in res for x in p]
+
+  def rand_search_pair(l):
+    """Random data pair for search task. Total length should be <= l."""
+    inp = [(np.random.randint(nclass - 1) + 1,
+            np.random.randint(nclass - 1) + 1) for _ in xrange(l-1/2)]
+    q = np.random.randint(nclass - 1) + 1
+    res = 0
+    for (k, v) in reversed(inp):
+      if k == q:
+        res = v
+    return [x for p in inp for x in p] + [q], [res]
+
+  def rand_kvsort_pair(l):
+    """Random data pair for key-value sort. Total length should be <= l."""
+    keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(l/2)]
+    vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(l/2)]
+    kv = [(k, vals[i]) for (k, i) in keys]
+    sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
+    return [x for p in kv for x in p], [x for p in sorted_kv for x in p]
+
   def spec(inp):
     """Return the target given the input for some tasks."""
     if task == "sort":
@@ -140,7 +169,7 @@ def init_data(task, length, nbr_cases, nclass):
     cur_time = time.time()
     if l > 10000 and case % 100 == 1:
       print_out("  avg gen time %.4f s" % (total_time / float(case)))
-    if task in ["add", "badd", "qadd", "bmul"]:
+    if task in ["add", "badd", "qadd", "bmul", "mul"]:
       i, t = rand_pair(l, task)
       train_set[task][len(i)].append([i, t])
       i, t = rand_pair(l, task)
@@ -150,6 +179,21 @@ def init_data(task, length, nbr_cases, nclass):
       train_set[task][len(i)].append([i, t])
       i, t = rand_dup_pair(l)
       test_set[task][len(i)].append([i, t])
+    elif task == "rev2":
+      i, t = rand_rev2_pair(l)
+      train_set[task][len(i)].append([i, t])
+      i, t = rand_rev2_pair(l)
+      test_set[task][len(i)].append([i, t])
+    elif task == "search":
+      i, t = rand_search_pair(l)
+      train_set[task][len(i)].append([i, t])
+      i, t = rand_search_pair(l)
+      test_set[task][len(i)].append([i, t])
+    elif task == "kvsort":
+      i, t = rand_kvsort_pair(l)
+      train_set[task][len(i)].append([i, t])
+      i, t = rand_kvsort_pair(l)
+      test_set[task][len(i)].append([i, t])
     else:
       inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
       target = spec(inp)

+ 48 - 11
neural_gpu/neural_gpu.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """The Neural GPU Model."""
 
 import time
@@ -47,17 +46,46 @@ def sigmoid_cutoff(x, cutoff):
   return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
 
 
+def tanh_cutoff(x, cutoff):
+  """Tanh with cutoff, e.g., 1.1tanh(x) cut to [-1. 1]."""
+  y = tf.tanh(x)
+  if cutoff < 1.01: return y
+  d = (cutoff - 1.0) / 2.0
+  return tf.minimum(1.0, tf.maximum(-1.0, (1.0 + d) * y))
+
+
 def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
   """Convolutional GRU."""
   def conv_lin(args, suffix, bias_start):
     return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                        prefix + "/" + suffix)
   reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
+  # candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
   candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
   gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
   return gate * mem + (1 - gate) * candidate
 
 
+@tf.RegisterGradient("CustomIdG")
+def _custom_id_grad(_, grads):
+  return grads
+
+
+def quantize(t, quant_scale, max_value=1.0):
+  """Quantize a tensor t with each element in [-max_value, max_value]."""
+  t = tf.minimum(max_value, tf.maximum(t, -max_value))
+  big = quant_scale * (t + max_value) + 0.5
+  with tf.get_default_graph().gradient_override_map({"Floor": "CustomIdG"}):
+    res = (tf.floor(big) / quant_scale) - max_value
+  return res
+
+
+def quantize_weights_op(quant_scale, max_value):
+  ops = [v.assign(quantize(v, quant_scale, float(max_value)))
+         for v in tf.trainable_variables()]
+  return tf.group(*ops)
+
+
 def relaxed_average(var_name_suffix, rx_step):
   """Calculate the average of relaxed variables having var_name_suffix."""
   relaxed_vars = []
@@ -117,7 +145,7 @@ class NeuralGPU(object):
 
   def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
                max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
-               learning_rate, pull, pull_incr, min_length):
+               learning_rate, pull, pull_incr, min_length, act_noise=0.0):
     # Feeds for parameters and ops to update them.
     self.global_step = tf.Variable(0, trainable=False)
     self.cur_length = tf.Variable(min_length, trainable=False)
@@ -195,7 +223,9 @@ class NeuralGPU(object):
       first = tf.concat(2, first)
 
       # Computation steps.
-      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
+      keep_prob = 1.0 - self.do_training * (dropout * 8.0 / float(length))
+      step = [tf.nn.dropout(first, keep_prob) * mask]
+      act_noise_scale = act_noise * self.do_training * self.pull
       outputs = []
       for it in xrange(length):
         with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
@@ -205,9 +235,12 @@ class NeuralGPU(object):
           # Do nconvs-many CGRU steps.
           for layer in xrange(nconvs):
             cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
-          cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
+            cur *= mask
+          outputs.append(tf.slice(cur, [0, 0, 0, 0], [-1, -1, 1, -1]))
+          cur = tf.nn.dropout(cur, keep_prob)
+          if act_noise > 0.00001:
+            cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
           step.append(cur * mask)
-          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))
 
       self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                          for s in step])
@@ -216,8 +249,10 @@ class NeuralGPU(object):
       # Final convolution to get logits, list outputs.
       output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
       output = tf.reshape(output, [-1, length, noclass])
-      self.outputs.append([tf.reshape(o, [-1, noclass])
-                           for o in list(tf.split(1, length, output))])
+      external_output = [tf.reshape(o, [-1, noclass])
+                         for o in list(tf.split(1, length, output))]
+      external_output = [tf.nn.softmax(o) for o in external_output]
+      self.outputs.append(external_output)
 
       # Calculate cross-entropy loss and normalize it.
       targets = tf.concat(1, [make_dense(self.target[l], noclass)
@@ -252,7 +287,8 @@ class NeuralGPU(object):
                            " %.2f s." % (length, time.time() - start_time))
     self.saver = tf.train.Saver(tf.all_variables())
 
-  def step(self, sess, inp, target, do_backward, noise_param=None):
+  def step(self, sess, inp, target, do_backward, noise_param=None,
+           get_steps=False):
     """Run a step of the network."""
     assert len(inp) == len(target)
     length = len(target)
@@ -272,8 +308,9 @@ class NeuralGPU(object):
     for l in xrange(length):
       feed_in[self.target[l].name] = target[l]
       feed_out.append(self.outputs[index][l])
-    for l in xrange(length+1):
-      feed_out.append(self.steps[index][l])
+    if get_steps:
+      for l in xrange(length+1):
+        feed_out.append(self.steps[index][l])
     res = sess.run(feed_out, feed_in)
     offset = 0
     norm = None
@@ -281,5 +318,5 @@ class NeuralGPU(object):
       offset = 2
       norm = res[1]
     outputs = res[offset + 1:offset + 1 + length]
-    steps = res[offset + 1 + length:]
+    steps = res[offset + 1 + length:] if get_steps else None
     return res[offset], outputs, norm, steps

+ 70 - 24
neural_gpu/neural_gpu_trainer.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Neural GPU for Learning Algorithms."""
 
 import math
@@ -31,21 +30,21 @@ from tensorflow.python.platform import gfile
 import data_utils as data
 import neural_gpu
 
-tf.app.flags.DEFINE_float("lr", 0.003, "Learning rate.")
+tf.app.flags.DEFINE_float("lr", 0.001, "Learning rate.")
 tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
-tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
+tf.app.flags.DEFINE_float("max_grad_norm", 1.0, "Clip gradients to this norm.")
 tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
 tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
 tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
-tf.app.flags.DEFINE_float("curriculum_bound", 0.08, "Move curriculum < this.")
+tf.app.flags.DEFINE_float("curriculum_bound", 0.15, "Move curriculum < this.")
 tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
-tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
+tf.app.flags.DEFINE_float("grad_noise_scale", 0.0, "Gradient noise scale.")
 tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size.")
 tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
 tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200, "Steps per epoch.")
-tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
-tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).")
-tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).")
+tf.app.flags.DEFINE_integer("nmaps", 128, "Number of floats in each cell.")
+tf.app.flags.DEFINE_integer("niclass", 33, "Number of classes (0 is padding).")
+tf.app.flags.DEFINE_integer("noclass", 33, "Number of classes (0 is padding).")
 tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
 tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
 tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
@@ -58,8 +57,11 @@ tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
 tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
 tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
 tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.")
+tf.app.flags.DEFINE_bool("animate", False, "Whether to produce an animation.")
+tf.app.flags.DEFINE_bool("quantize", False, "Whether to quantize variables.")
 tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
 tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
+tf.app.flags.DEFINE_string("ensemble", "", "Model paths for ensemble.")
 
 FLAGS = tf.app.flags.FLAGS
 EXTRA_EVAL = 12
@@ -78,7 +80,6 @@ def initialize(sess):
   np.random.seed(seed)
 
   # Check data sizes.
-  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
   assert data.bins
   min_length = 3
   max_length = min(FLAGS.max_length, data.bins[-1])
@@ -86,6 +87,7 @@ def initialize(sess):
   while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
     data.bins = data.bins[:-1]
   assert data.bins[0] > FLAGS.rx_step
+  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
   nclass = min(FLAGS.niclass, FLAGS.noclass)
   data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000
 
@@ -136,15 +138,24 @@ def initialize(sess):
                    % ckpt.model_checkpoint_path)
     model.saver.restore(sess, ckpt.model_checkpoint_path)
 
+  # Check if there are ensemble models and get their checkpoints.
+  ensemble = []
+  ensemble_dir_list = [d for d in FLAGS.ensemble.split(",") if d]
+  for ensemble_dir in ensemble_dir_list:
+    ckpt = tf.train.get_checkpoint_state(ensemble_dir)
+    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
+      data.print_out("Found ensemble model %s" % ckpt.model_checkpoint_path)
+      ensemble.append(ckpt.model_checkpoint_path)
+
   # Return the model and needed variables.
-  return (model, min_length, max_length, checkpoint_dir, curriculum)
+  return (model, min_length, max_length, checkpoint_dir, curriculum, ensemble)
 
 
 def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
-                offset=None):
+                offset=None, ensemble=None, get_steps=False):
   """Test model on test data of length l using the given session."""
   inpt, target = data.get_batch(l, batch_size, False, task, offset)
-  _, res, _, steps = model.step(sess, inpt, target, False)
+  _, res, _, steps = model.step(sess, inpt, target, False, get_steps=get_steps)
   errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
   seq_err = float(seq_err) / batch_size
   if total > 0:
@@ -152,10 +163,34 @@ def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
   if print_out:
     data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
                    % (task, l, 100*errors, 100*seq_err))
+  # Ensemble eval.
+  if ensemble:
+    results = []
+    for m in ensemble:
+      model.saver.restore(sess, m)
+      _, result, _, _ = model.step(sess, inpt, target, False)
+      m_errors, m_total, m_seq_err = data.accuracy(inpt, result, target,
+                                                   batch_size, nprint)
+      m_seq_err = float(m_seq_err) / batch_size
+      if total > 0:
+        m_errors = float(m_errors) / m_total
+      data.print_out("     %s len %d m-errors %.2f m-sequence-errors %.2f"
+                     % (task, l, 100*m_errors, 100*m_seq_err))
+      results.append(result)
+    ens = [sum(o) for o in zip(*results)]
+    errors, total, seq_err = data.accuracy(inpt, ens, target,
+                                           batch_size, nprint)
+    seq_err = float(seq_err) / batch_size
+    if total > 0:
+      errors = float(errors) / total
+    if print_out:
+      data.print_out("  %s len %d ens-errors %.2f ens-sequence-errors %.2f"
+                     % (task, l, 100*errors, 100*seq_err))
   return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])
 
 
-def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
+def multi_test(l, model, sess, task, nprint, batch_size, offset=None,
+               ensemble=None):
   """Run multiple tests at lower batch size to save memory."""
   errors, seq_err = 0.0, 0.0
   to_print = nprint
@@ -164,7 +199,7 @@ def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
   for mstep in xrange(batch_size / low_batch):
     cur_offset = None if offset is None else offset + mstep * low_batch
     err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
-                                 False, cur_offset)
+                                 False, cur_offset, ensemble=ensemble)
     to_print = max(0, to_print - low_batch)
     errors += err
     seq_err += sq_err
@@ -185,7 +220,9 @@ def train():
   batch_size = FLAGS.batch_size
   tasks = FLAGS.task.split("-")
   with tf.Session() as sess:
-    model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
+    (model, min_length, max_length, checkpoint_dir,
+     curriculum, _) = initialize(sess)
+    quant_op = neural_gpu.quantize_weights_op(512, 8)
     max_cur_length = min(min_length + 3, max_length)
     prev_acc_perp = [1000000 for _ in xrange(3)]
     prev_seq_err = 1.0
@@ -246,6 +283,10 @@ def train():
 
       # If errors are below the curriculum threshold, move curriculum forward.
       if curriculum > acc_seq_err:
+        if FLAGS.quantize:
+          # Quantize weights.
+          data.print_out("  Quantizing parameters.")
+          sess.run([quant_op])
         # Increase current length (until the next with training data).
         do_incr = True
         while do_incr and max_cur_length < max_length:
@@ -260,7 +301,9 @@ def train():
           sess.run(model.pull_incr_op)
         else:
           data.print_out("  Averaging parameters.")
-          sess.run([model.avg_op, model.lr_decay_op])
+          sess.run(model.avg_op)
+          if acc_seq_err < (curriculum / 3.0):
+            sess.run(model.lr_decay_op)
 
       # Lower learning rate if we're worse than the last 3 checkpoints.
       acc_perp = data.safe_exp(acc_loss)
@@ -358,32 +401,35 @@ def evaluate():
   batch_size = FLAGS.batch_size
   tasks = FLAGS.task.split("-")
   with tf.Session() as sess:
-    model, min_length, max_length, _, _ = initialize(sess)
+    model, min_length, max_length, _, _, ensemble = initialize(sess)
     bound = data.bins[-1] + 1
     for t in tasks:
       l = min_length
       while l < max_length + EXTRA_EVAL and l < bound:
-        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint,
+                                    batch_size, ensemble=ensemble)
         l += 1
         while l < bound + 1 and not data.test_set[t][l]:
           l += 1
       # Animate.
-      anim_size = 2
-      _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
-      animate(l, test_data, anim_size)
+      if FLAGS.animate:
+        anim_size = 2
+        _, _, test_data = single_test(l, model, sess, t, 0, anim_size,
+                                      get_steps=True)
+        animate(l, test_data, anim_size)
       # More tests.
       _, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-                              batch_size * 4)
+                              batch_size * 4, ensemble=ensemble)
     if seq_err < 0.01:  # Super-test if we're very good and in large-test mode.
       if data.forward_max > 4000 and len(tasks) == 1:
         multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
-                   batch_size * 64, 0)
+                   batch_size * 64, 0, ensemble=ensemble)
 
 
 def interactive():
   """Interactively probe an existing model."""
   with tf.Session() as sess:
-    model, _, _, _, _ = initialize(sess)
+    model, _, _, _, _, _ = initialize(sess)
     sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
     sys.stdout.write("> ")
     sys.stdout.flush()