|
|
@@ -0,0 +1,376 @@
|
|
|
+"""Neural GPU for Learning Algorithms."""
|
|
|
+
|
|
|
+import math
|
|
|
+import os
|
|
|
+import random
|
|
|
+import sys
|
|
|
+import time
|
|
|
+
|
|
|
+import google3
|
|
|
+
|
|
|
+import matplotlib.animation as anim
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+import tensorflow as tf
|
|
|
+
|
|
|
+from google3.third_party.tensorflow.python.platform import gfile
|
|
|
+import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
|
|
|
+import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu
|
|
|
+
|
|
|
# Command-line flags. Defaults are unchanged; only the help text of `mode`
# was corrected so that it matches the dispatch performed in main().
tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.")
tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
# main() runs train() for 0, evaluate() for 1 and interactive() otherwise.
tf.app.flags.DEFINE_integer("mode", 0,
                            "Mode: 0-train 1-evaluate other-interactive.")
tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")

FLAGS = tf.app.flags.FLAGS
|
|
|
+
|
|
|
+
|
|
|
def initialize(sess):
  """Initialize data and model.

  Seeds the random generators, trims `data.bins` to the lengths needed for
  `FLAGS.max_length`, generates data for every task in `FLAGS.task`, builds
  the NeuralGPU model, initializes its variables in `sess` and restores the
  latest checkpoint from the checkpoint directory when one exists.

  Args:
    sess: the TensorFlow session in which variables are initialized/restored.

  Returns:
    A tuple (model, min_length, max_length, checkpoint_dir, curriculum).
  """
  if FLAGS.jobid >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
  data.print_out("NN ", newline=False)

  # Set random seed (offset by the job id so parallel jobs differ).
  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)

  # Check data sizes. The emptiness check must come before the first
  # data.bins[-1] access, otherwise it could never trigger.
  assert data.bins
  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
  min_length = 3
  max_length = min(FLAGS.max_length, data.bins[-1])
  assert max_length + 1 > min_length
  # Drop trailing bins that are not needed for lengths up to max_length + 12.
  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
    data.bins = data.bins[:-1]
  assert data.bins[0] > FLAGS.rx_step
  nclass = min(FLAGS.niclass, FLAGS.noclass)
  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000

  # Initialize data for each task.
  tasks = FLAGS.task.split("-")
  end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
  for t in tasks:
    for l in range(max_length + 11):
      data.init_data(t, l, data_size, nclass)
    # With a single remaining bin there is no second-to-last bin to prepare;
    # indexing data.bins[-2] would raise IndexError in that case.
    if len(data.bins) > 1:
      data.init_data(t, data.bins[-2], data_size, nclass)
    data.init_data(t, data.bins[-1], data_size, nclass)
    data.init_data(t, data.forward_max, end_size, nclass)

  # Print out parameters.
  curriculum = 0.12
  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
  fin = "data %d %s" % (FLAGS.train_data_size, fin)
  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
  data.print_out(tag)

  # Create checkpoint directory if it does not exist.
  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
  if not gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    gfile.MkDir(checkpoint_dir)

  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
  model = ngpu.NeuralGPU(
      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
  data.print_out("Created model.")
  sess.run(tf.initialize_all_variables())
  data.print_out("Initialized variables.")

  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)

  # Return the model and needed variables.
  return (model, min_length, max_length, checkpoint_dir, curriculum)
|
|
|
+
|
|
|
+
|
|
|
def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
                offset=None):
  """Run the model once on a test batch of length l.

  Args:
    l: length of the test examples to fetch.
    model: model object whose `step` method performs the forward pass.
    sess: session handed through to `model.step`.
    task: task name used to fetch the batch and in the log line.
    nprint: forwarded to `data.accuracy` (number of examples to print).
    batch_size: number of examples in the batch.
    print_out: when true, log the error rates through `data.print_out`.
    offset: optional offset forwarded to `data.get_batch`.

  Returns:
    (errors, seq, (steps, inpt, predictions)): `errors` is the error count
    divided by `total` when `total` is positive (otherwise the raw value from
    `data.accuracy`), `seq` is the sequence-error count divided by
    `batch_size`, and `predictions` holds the argmax over axis 1 of every
    model output.
  """
  inpt, target = data.get_batch(l, batch_size, False, task, offset)
  _, res, _, steps = model.step(sess, inpt, target, False)
  err_count, total, seq_count = data.accuracy(inpt, res, target, batch_size,
                                              nprint)
  seq_rate = float(seq_count) / batch_size
  # Without any counted symbols the raw error value is passed through as is.
  err_rate = float(err_count) / total if total > 0 else err_count
  if print_out:
    data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                   % (task, l, 100*err_rate, 100*seq_rate))
  predictions = [np.argmax(o, axis=1) for o in res]
  return err_rate, seq_rate, (steps, inpt, predictions)
|
|
|
+
|
|
|
+
|
|
|
def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
  """Run multiple tests at lower batch size to save memory.

  The requested `batch_size` is split into `batch_size // low_batch` chunks of
  `low_batch = min(FLAGS.low_batch_size, batch_size)` examples, each evaluated
  with `single_test`. A remainder that does not fill a whole chunk is not
  evaluated, so the reported rates are averaged over the chunks that actually
  ran rather than over `batch_size`.

  Args:
    l: length of the test examples.
    model: model passed to `single_test`.
    sess: session passed to `single_test`.
    task: task name.
    nprint: total number of examples to print, spread over the chunks.
    batch_size: total number of examples requested.
    offset: optional starting offset; chunk k uses `offset + k * low_batch`.

  Returns:
    (errors, seq): mean error rate and mean sequence-error rate per chunk.
  """
  low_batch = min(FLAGS.low_batch_size, batch_size)
  num_chunks = batch_size // low_batch
  errors = 0.0
  seq = 0.0
  to_print = nprint
  for mstep in range(num_chunks):
    cur_offset = None if offset is None else offset + mstep * low_batch
    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
                             cur_offset)
    to_print = max(0, to_print - low_batch)
    errors += err
    seq += sq
    if FLAGS.mode > 0:
      # Running average over the chunks evaluated so far.
      cur_errors = errors / (mstep + 1)
      cur_seq = seq / (mstep + 1)
      data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
                     % (task, 100*cur_errors, 100*cur_seq))
  # All chunks have the same size, so the mean over chunks is the mean over
  # evaluated examples; dividing by batch_size would under-report the rates
  # whenever batch_size is not a multiple of low_batch.
  if num_chunks > 0:
    errors /= num_chunks
    seq /= num_chunks
  data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                 % (task, l, 100*errors, 100*seq))
  return errors, seq
|
|
|
+
|
|
|
+
|
|
|
def _sample_training_length(min_length, max_cur_length, max_length):
  """Draw the length of the next training batch.

  A base length is drawn uniformly from [min_length, max_cur_length]. With
  probability 6/10 it is replaced by the maximum of itself and a second such
  draw, and with probability 1/4 by the maximum of the base and a draw from
  the full range [min_length, max_length].
  """
  base = np.random.randint(max_cur_length - min_length + 1) + min_length
  length = base
  if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
    length = np.random.randint(max_cur_length - min_length + 1) + min_length
    length = max(length, base)
  if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
    length = np.random.randint(max_length - min_length + 1) + min_length
    length = max(length, base)
  return length


def train():
  """Main training function.

  Repeats forever: run `FLAGS.steps_per_checkpoint` training steps, log the
  averaged statistics, adjust the curriculum length / pull / learning rate,
  save a checkpoint and evaluate every task. The loop has no exit condition;
  `should_exit` only triggers one additional large test.
  """
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
    prev_acc_perp = [1000000 for _ in range(3)]
    prev_sq = 1.0

    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      ep = global_step // FLAGS.steps_per_checkpoint
      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in range(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
        l = _sample_training_length(min_length, max_cur_length, max_length)
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
        # Gradient noise decays with the step count and scales with the
        # sequence-error rate measured in the previous epoch.
        stepp = math.pow(global_step, -0.55)
        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)
        # Only batches within the current curriculum length are counted.
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq = data.accuracy(inp, res, target,
                                             batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq += seq
      step_time /= FLAGS.steps_per_checkpoint

      if step_count == 0:
        # Every sampled batch was longer than the curriculum length, so there
        # is nothing to average; dividing by step_count would raise
        # ZeroDivisionError. Keep the previous statistics untouched.
        data.print_out("ep %d st %.2f len %d: no batch within the curriculum "
                       "length, skipping statistics and schedule update."
                       % (ep, step_time, max_cur_length))
      else:
        acc_loss /= step_count
        acc_seq = float(acc_seq) / (step_count * batch_size)
        prev_sq = acc_seq
        acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
        msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
        msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
        msg = ("%s %s gn %.8f"
               % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
        data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
                       (msg, max_cur_length, data.safe_exp(acc_loss),
                        100*acc_errors, 100*acc_seq))
        if curriculum > acc_seq:
          # Sequence errors are below the curriculum threshold: move on.
          prev_acc_perp.append(1000000)
          do_incr = True
          # NOTE(review): max_cur_length is not refreshed inside this loop, so
          # it ends only once some task has a truthy data.train_set[t] --
          # confirm that this is the intended stopping rule.
          while do_incr and max_cur_length < max_length:
            sess.run(model.cur_length_incr_op)
            for t in tasks:
              if data.train_set[t]:
                do_incr = False
          if pull < 1:
            sess.run(model.pull_incr_op)
          else:
            data.print_out(" Averaging parameters.")
            sess.run([model.avg_op, model.lr_decay_op])
        else:
          # Decay the learning rate when perplexity is worse than each of the
          # last three recorded values.
          acc_perp = data.safe_exp(acc_loss)
          if acc_perp > max(prev_acc_perp[-3:]):
            sess.run(model.lr_decay_op)
          prev_acc_perp.append(acc_perp)

      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
      # Run evaluation.
      should_exit = True
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + 12 and l < bound:
          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
          l += 1
          # Skip lengths for which no test data is present.
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if sq < 0.5:
          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                             batch_size * 4)
        if sq > 0.001:
          should_exit = False
      if should_exit:
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)
|
|
|
+
|
|
|
+
|
|
|
def animate(l, test_data, anim_size):
  """Create animation for the given data (hacky matplotlib use).

  Every example occupies fps * (l + 4 * xf) frames. While the per-example
  index is below xf the input symbols are shown over a blank image, from xf
  to 2 * xf the first recorded state is shown without text, and afterwards
  the recorded states are stepped through; the output symbols appear once
  the index has moved past the last recorded state. The movie is written to
  /tmp/neural_gpu.mp4.

  Args:
    l: number of text rows, i.e. the length of the animated examples.
    test_data: tuple (steps, inpt, out_raw), as produced by the third return
      value of single_test.
    anim_size: number of examples to animate.
  """
  xf = 12
  fps = 2
  fig = plt.figure(figsize=(16, 9), facecolor="white")
  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
  ax.set_xticks([col * 24 - 0.5 for col in range(4)])
  ax.set_xticklabels([])
  ax.set_yticks([row - 0.5 for row in range(l + 1)])
  ax.grid(which="major", axis="both", linestyle="-", color="black")

  text_size = 24 * 32 // l
  label_size = text_size - (4 * 32 // l)
  # One (initially empty) text artist per row; a fresh bbox dict is built for
  # each of them.
  text_fields = [
      ax.text(11.25, row + 0.15, "", color="g", ha="center", va="center",
              bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
              size=label_size, animated=True)
      for row in range(l)]
  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                 interpolation="none", animated=True)
  im.set_zorder(1)

  def to_symbol(i):
    """Map a class id to its display string (0 is shown as empty)."""
    if i == 0:
      return ""
    elif i == 11:
      return "+"
    elif i == 12:
      return "*"
    return str(i-1)

  def animation_update(frame_no, test_data, xf, im, text_fields):
    """Update an animation frame."""
    steps, inpt, out_raw = test_data
    num_states = len(steps)
    frames_per_example = fps * (l + 4 * xf)
    batch = frame_no // frames_per_example
    index = int((frame_no % frames_per_example) // fps)
    # Cut output after first padding.
    out = [out_raw[row][batch] for row in range(len(text_fields))]
    if 0 in out:
      first_pad = out.index(0)
      out = out[:first_pad] + [0] * (len(out) - first_pad)
    shown = index - 2 * xf
    if shown < 0:
      # Intro: input symbols first, then the initial state without text.
      showing_input = index < xf
      for row, field in enumerate(text_fields):
        field.set_text(to_symbol(inpt[row][batch]) if showing_input else "")
      if showing_input:
        im.set_array(np.zeros_like(steps[0][0]))
      else:
        im.set_array(steps[0][batch])
    else:
      # Show the state after the first frames.
      im.set_array(steps[min(num_states - 1, shown)][batch])
      for row, field in enumerate(text_fields):
        field.set_text("" if shown < num_states else to_symbol(out[row]))
    return im,

  total_frames = (l + 4 * xf) * anim_size * fps
  animation = anim.FuncAnimation(
      fig, animation_update, blit=True, frames=total_frames,
      interval=500 // fps, fargs=(test_data, xf, im, text_fields))
  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)
|
|
|
+
|
|
|
+
|
|
|
def evaluate():
  """Evaluate an existing model.

  For every task in FLAGS.task: test each available length, render an
  animation for the length reached at the end of that sweep, then run a
  multi-batch test at data.forward_max. When the last of those tests has a
  sequence-error rate below 0.01, forward_max exceeds 4000 and there is a
  single task, one more, larger test is run.
  """
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, _, _ = initialize(sess)
    bound = data.bins[-1] + 1
    sweep_end = min(max_length + 12, bound)
    for t in tasks:
      l = min_length
      while l < sweep_end:
        single_test(l, model, sess, t, FLAGS.nprint, batch_size)
        l += 1
        # Jump over lengths without test data.
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
      # Animate.
      anim_size = 2
      test_data = single_test(l, model, sess, t, 0, anim_size)[2]
      animate(l, test_data, anim_size)
      # More tests.
      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                         batch_size * 4)
    # More tests.
    if sq < 0.01 and data.forward_max > 4000 and len(tasks) == 1:
      multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                 batch_size * 64, 0)
|
|
|
+
|
|
|
+
|
|
|
def interactive():
  """Interactively probe an existing model.

  Reads lines from stdin until end of input. Every character of a line is
  parsed as one decimal digit and used as an input symbol id; the argmax
  prediction of the model for each output position is printed, separated by
  spaces. Lines containing anything other than digits are reported and
  skipped instead of aborting the session with a ValueError.
  """
  with tf.Session() as sess:
    model, _, _, _, _ = initialize(sess)
    sys.stdout.write("> ")
    sys.stdout.flush()
    line = sys.stdin.readline()
    while line:  # readline() returns an empty string at end of input.
      try:
        ids = [int(c) for c in line.strip()]
      except ValueError:
        sys.stdout.write("Expected a sequence of digits, got %r.\n"
                         % line.strip())
      else:
        inpt, target = data.get_batch(len(ids), 1, False, "",
                                      preset=(ids, [0 for _ in ids]))
        _, res, _, _ = model.step(sess, inpt, target, False)
        res = [np.argmax(o, axis=1) for o in res]
        sys.stdout.write(" ".join([str(output[0]) for output in res]) + "\n")
      sys.stdout.write("> ")
      sys.stdout.flush()
      line = sys.stdin.readline()
|
|
|
+
|
|
|
+
|
|
|
def main(_):
  """Dispatch on FLAGS.mode: 0 trains, 1 evaluates, anything else is interactive."""
  entry_points = {0: train, 1: evaluate}
  entry_points.get(FLAGS.mode, interactive)()


if __name__ == "__main__":
  tf.app.run()
|