neural_gpu_trainer.py

  1. """Neural GPU for Learning Algorithms."""
  2. import math
  3. import os
  4. import random
  5. import sys
  6. import time
  7. import matplotlib.animation as anim
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import tensorflow as tf
  11. from tensorflow.python.platform import gfile
  12. import data_utils as data
  13. import neural_gpu
tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per checkpoint.")
tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.")
tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")

FLAGS = tf.app.flags.FLAGS
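# Example invocations (paths are illustrative placeholders):
#   python neural_gpu_trainer.py --task=rev --train_dir=/tmp  # mode 0: train
#   python neural_gpu_trainer.py --mode=1 --train_dir=/tmp    # evaluate
#   python neural_gpu_trainer.py --mode=2 --train_dir=/tmp    # interactive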
EXTRA_EVAL = 12
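# Evaluation also covers lengths past the training maximum (up to
# max_length + EXTRA_EVAL) to probe generalization to longer inputs.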


def initialize(sess):
  """Initialize data and model."""
  if FLAGS.jobid >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
  data.print_out("NN ", newline=False)
  # Set random seed.
  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)
  # Check data sizes.
  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
  assert data.bins
  min_length = 3
  max_length = min(FLAGS.max_length, data.bins[-1])
  assert max_length + 1 > min_length
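  # Drop length bins beyond what training and evaluation will use.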
  while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
    data.bins = data.bins[:-1]
  assert data.bins[0] > FLAGS.rx_step
  nclass = min(FLAGS.niclass, FLAGS.noclass)
  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000
  # Initialize data for each task.
  tasks = FLAGS.task.split("-")
  for t in tasks:
    for l in xrange(max_length + EXTRA_EVAL - 1):
      data.init_data(t, l, data_size, nclass)
    data.init_data(t, data.bins[-2], data_size, nclass)
    data.init_data(t, data.bins[-1], data_size, nclass)
    end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
    data.init_data(t, data.forward_max, end_size, nclass)
  # Print out parameters.
  curriculum = 0.12
  msg1 = ("layers %d kw %d h %d kh %d relax %d batch %d noise %.2f task %s"
          % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
             FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
  msg2 = "data %d %s" % (FLAGS.train_data_size, msg1)
  msg3 = ("cut %.2f pull %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
          (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
           curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, msg2))
  data.print_out(msg3)
  # Create checkpoint directory if it does not exist.
  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
  if not gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    gfile.MkDir(checkpoint_dir)
  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
  model = neural_gpu.NeuralGPU(
      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
  data.print_out("Created model.")
  sess.run(tf.initialize_all_variables())
  data.print_out("Initialized variables.")
  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)
  # Return the model and needed variables.
  return (model, min_length, max_length, checkpoint_dir, curriculum)


def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
                offset=None):
  """Test model on test data of length l using the given session."""
  inpt, target = data.get_batch(l, batch_size, False, task, offset)
  _, res, _, steps = model.step(sess, inpt, target, False)
  errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
  seq_err = float(seq_err) / batch_size
  if total > 0:
    errors = float(errors) / total
  if print_out:
    data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                   % (task, l, 100*errors, 100*seq_err))
  return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])


def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
  """Run multiple tests at lower batch size to save memory."""
  errors, seq_err = 0.0, 0.0
  to_print = nprint
  low_batch = FLAGS.low_batch_size
  low_batch = min(low_batch, batch_size)
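  # Illustration: with the default batch_size=64 and low_batch_size=16,
  # the test batch is evaluated as 4 consecutive chunks of 16 examples.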
  for mstep in xrange(batch_size / low_batch):
    cur_offset = None if offset is None else offset + mstep * low_batch
    err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
                                 False, cur_offset)
    to_print = max(0, to_print - low_batch)
    errors += err
    seq_err += sq_err
    if FLAGS.mode > 0:
      cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
      cur_seq_err = float(low_batch * seq_err) / ((mstep+1) * low_batch)
      data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
                     % (task, 100*cur_errors, 100*cur_seq_err))
  errors = float(low_batch) * float(errors) / batch_size
  seq_err = float(low_batch) * float(seq_err) / batch_size
  data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                 % (task, l, 100*errors, 100*seq_err))
  return errors, seq_err


def train():
  """Train the model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    (model, min_length, max_length,
     checkpoint_dir, curriculum) = initialize(sess)
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(3)]
    prev_seq_err = 1.0
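    # prev_acc_perp holds checkpoint perplexities; the large initial values
    # act as "no history yet" sentinels so early decay never triggers.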
    # Main training loop.
    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      acc_loss, acc_total, acc_errors, acc_seq_err = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in xrange(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
        # Select the length for curriculum learning.
        l = np.random.randint(max_cur_length - min_length + 1) + min_length
        # Prefer longer stuff 60% of time.
        if np.random.randint(100) < 60:
          l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
          l = max(l, l1)
        # Mixed curriculum learning: in 25% of cases go to any larger length.
        if np.random.randint(100) < 25:
          l1 = np.random.randint(max_length - min_length + 1) + min_length
          l = max(l, l1)
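        # Net effect: lengths come mostly from the current curriculum window,
        # biased toward its upper end, with a 25% chance of jumping anywhere
        # up to max_length.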
        # Run a step and time it.
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
        noise_param = math.sqrt(math.pow(global_step, -0.55) *
                                (20 * prev_seq_err)) * FLAGS.grad_noise_scale
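        # Gradient noise decays with the step and shrinks with sequence error;
        # e.g. at global_step=1000 with prev_seq_err=0.5 this is roughly
        # sqrt(1000**-0.55 * 10) ~= 0.47 times grad_noise_scale.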
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)
        # Accumulate statistics only if we did not exceed curriculum length.
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq_err = data.accuracy(inp, res, target,
                                                 batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq_err += seq_err
      # Normalize and print out accumulated statistics.
      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
      prev_seq_err = acc_seq_err
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      msg1 = "step %d step-time %.2f" % (global_step, step_time)
      msg2 = "lr %.8f pull %.3f" % (learning_rate, pull)
      msg3 = ("%s %s grad-norm %.8f"
              % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppx %.8f errors %.2f sequence-errors %.2f" %
                     (msg3, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq_err))
      # If errors are below the curriculum threshold, move curriculum forward.
      if curriculum > acc_seq_err:
        # Increase current length (until the next with training data).
        do_incr = True
        while do_incr and max_cur_length < max_length:
          sess.run(model.cur_length_incr_op)
          for t in tasks:
            if data.train_set[t]: do_incr = False
        # Forget last perplexities if we're not yet at the end.
        if max_cur_length < max_length:
          prev_acc_perp.append(1000000)
        # Either increase pull or, if it's large, average parameters.
        if pull < 1:
          sess.run(model.pull_incr_op)
        else:
          data.print_out(" Averaging parameters.")
          sess.run([model.avg_op, model.lr_decay_op])
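        # Note: pull starts small (FLAGS.pull) and is raised by pull_incr_op
        # each time the curriculum advances; once it reaches 1 the schedule
        # switches to parameter averaging plus learning-rate decay.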
      # Lower learning rate if we're worse than the last 3 checkpoints.
      acc_perp = data.safe_exp(acc_loss)
      if acc_perp > max(prev_acc_perp[-3:]):
        sess.run(model.lr_decay_op)
      prev_acc_perp.append(acc_perp)
      # Save checkpoint.
      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
      # Run evaluation.
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + EXTRA_EVAL and l < bound:
          _, seq_err, _ = single_test(l, model, sess, t,
                                      FLAGS.nprint, batch_size)
          l += 1
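          # Skip to the next length that actually has test data.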
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if seq_err < 0.5:  # Run larger test if we're good enough.
          _, seq_err = multi_test(data.forward_max, model, sess, t,
                                  FLAGS.nprint, batch_size * 4)
      if seq_err < 0.01:  # Super-large test on 1-task large-forward models.
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)


def animate(l, test_data, anim_size):
  """Create animation for the given data (hacky matplotlib use)."""
  xf = 12  # Extra frames to slow down at start and end.
  fps = 2  # Frames per step.
  # Make the figure.
  fig = plt.figure(figsize=(16, 9), facecolor="white")
  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
  ax.set_xticks([i * 24 - 0.5 for i in xrange(4)])
  ax.set_xticklabels([])
  ax.set_yticks([i - 0.5 for i in xrange(l+1)])
  ax.grid(which="major", axis="both", linestyle="-", color="black")
  # We need text fields.
  text_fields = []
  text_size = 24*32/l
  for y in xrange(l):
    text_fields.append(ax.text(
        11.25, y + 0.15, "", color="g", ha="center", va="center",
        bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
        size=text_size - (4 * 32 / l), animated=True))
  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                 interpolation="none", animated=True)
  im.set_zorder(1)

  # Main animation step.
  def animation_update(frame_no, test_data, xf, im, text_fields):
    """Update an animation frame."""
    steps, inpt, out_raw = test_data
    length = len(steps)
    batch = frame_no / (fps * (l+4*xf))
    index = int((frame_no % (fps * (l+4*xf))) / fps)
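    # Each example occupies fps * (l + 4*xf) consecutive frames: `batch`
    # selects the example, `index` the position within its sequence.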
    # Cut output after first padding.
    out = [out_raw[i][batch] for i in xrange(len(text_fields))]
    if 0 in out:
      i = out.index(0)
      out = out[0:i] + [0 for _ in xrange(len(out) - i)]
    # Show the state after the first frames.
    if index >= 2*xf:
      im.set_array(steps[min(length - 1, index - 2*xf)][batch])
      for i, t in enumerate(text_fields):
        if index - 2*xf < length:
          t.set_text("")
        else:
          t.set_text(data.to_symbol(out[i]))
    else:
      for i, t in enumerate(text_fields):
        t.set_text(data.to_symbol(inpt[i][batch]) if index < xf else "")
      if index < xf:
        im.set_array(np.zeros_like(steps[0][0]))
      else:
        im.set_array(steps[0][batch])
    return im,

  # Create the animation and save to mp4.
  animation = anim.FuncAnimation(
      fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps,
      interval=500/fps, fargs=(test_data, xf, im, text_fields))
  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)


def evaluate():
  """Evaluate an existing model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, _, _ = initialize(sess)
    bound = data.bins[-1] + 1
    for t in tasks:
      l = min_length
      while l < max_length + EXTRA_EVAL and l < bound:
        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint,
                                    batch_size)
        l += 1
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
      # Animate.
      anim_size = 2
      _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
      animate(l, test_data, anim_size)
      # More tests.
      _, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                              batch_size * 4)
    if seq_err < 0.01:  # Super-test if we're very good and in large-test mode.
      if data.forward_max > 4000 and len(tasks) == 1:
        multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                   batch_size * 64, 0)


def interactive():
  """Interactively probe an existing model."""
  with tf.Session() as sess:
    model, _, _, _, _ = initialize(sess)
    sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
    sys.stdout.write("> ")
    sys.stdout.flush()
    inpt = sys.stdin.readline()
    while inpt:
      ids = [data.to_id(s) for s in inpt.strip().split()]
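      # Build a batch of size 1 from the typed symbols; the target is a
      # dummy all-zeros sequence since we only decode here.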
      inpt, target = data.get_batch(len(ids), 1, False, "",
                                    preset=(ids, [0 for _ in ids]))
      _, res, _, _ = model.step(sess, inpt, target, False)
      res = [np.argmax(o, axis=1) for o in res]
      res = [o for o in res[:len(ids)] if o > 0]
      print " " + " ".join([data.to_symbol(output[0]) for output in res])
      sys.stdout.write("> ")
      sys.stdout.flush()
      inpt = sys.stdin.readline()


def main(_):
  if FLAGS.mode == 0:
    train()
  elif FLAGS.mode == 1:
    evaluate()
  else:
    interactive()


if __name__ == "__main__":
  tf.app.run()