neural_gpu_trainer.py

  1. """Neural GPU for Learning Algorithms."""
  2. import math
  3. import os
  4. import random
  5. import sys
  6. import time
  7. import google3
  8. import matplotlib.animation as anim
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. import tensorflow as tf
  12. from google3.third_party.tensorflow.python.platform import gfile
  13. import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
  14. import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu

tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.")
tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("niclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("noclass", 14, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.")
tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")

FLAGS = tf.app.flags.FLAGS
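
# Example invocation (hypothetical; the exact launcher depends on the build
# setup):
#   python neural_gpu_trainer.py --task=rev --max_length=41 \
#       --train_dir=/tmp/neural_gpu --mode=0
# mode=0 trains, mode=1 runs evaluate(), and any other mode drops into the
# interactive() prompt (see main() at the bottom of this file).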


def initialize(sess):
  """Initialize data and model."""
  if FLAGS.jobid >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
  data.print_out("NN ", newline=False)

  # Set random seed.
  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)

  # Check data sizes.
  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
  assert data.bins
  min_length = 3
  max_length = min(FLAGS.max_length, data.bins[-1])
  assert max_length + 1 > min_length
  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
    data.bins = data.bins[:-1]
  assert data.bins[0] > FLAGS.rx_step
  nclass = min(FLAGS.niclass, FLAGS.noclass)
  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000

  # Initialize data for each task.
  tasks = FLAGS.task.split("-")
  for t in tasks:
    for l in xrange(max_length + 11):
      data.init_data(t, l, data_size, nclass)
    data.init_data(t, data.bins[-2], data_size, nclass)
    data.init_data(t, data.bins[-1], data_size, nclass)
    end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
    data.init_data(t, data.forward_max, end_size, nclass)
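
  # data.bins (defined in data_utils; its exact semantics are an assumption
  # about that module) lists the bucket lengths used for batching. Buckets
  # beyond max_length + 12 are trimmed above, and data.forward_max is the
  # extra-long length used only for generalization tests, so it gets its own
  # separately sized data set.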

  # Print out parameters.
  curriculum = 0.12
  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
  fin = "data %d %s" % (FLAGS.train_data_size, fin)
  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
  data.print_out(tag)

  # Create checkpoint directory if it does not exist.
  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
  if not gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    gfile.MkDir(checkpoint_dir)

  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
  model = ngpu.NeuralGPU(
      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
  data.print_out("Created model.")
  sess.run(tf.initialize_all_variables())
  data.print_out("Initialized variables.")

  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)

  # Return the model and needed variables.
  return (model, min_length, max_length, checkpoint_dir, curriculum)


def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
                offset=None):
  """Test model on test data of length l using the given session."""
  inpt, target = data.get_batch(l, batch_size, False, task, offset)
  _, res, _, steps = model.step(sess, inpt, target, False)
  errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
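  # data.accuracy (semantics assumed from its use here) returns the number of
  # wrong output symbols, the number of scored symbols, and the number of
  # sequences with at least one wrong symbol; the lines below normalize these
  # into a per-symbol and a per-sequence error fraction.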
  seq = float(seq) / batch_size
  if total > 0:
    errors = float(errors) / total
  if print_out:
    data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
                   % (task, l, 100*errors, 100*seq))
  return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])


def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
  """Run multiple tests at lower batch size to save memory."""
  errors = 0.0
  seq = 0.0
  to_print = nprint
  low_batch = FLAGS.low_batch_size
  low_batch = min(low_batch, batch_size)
  for mstep in xrange(batch_size / low_batch):
    cur_offset = None if offset is None else offset + mstep * low_batch
    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
                             cur_offset)
    to_print = max(0, to_print - low_batch)
    errors += err
    seq += sq
    if FLAGS.mode > 0:
      cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
      cur_seq = float(low_batch * seq) / ((mstep+1) * low_batch)
      data.print_out("  %s multitest current errors %.2f sequence-errors %.2f"
                     % (task, 100*cur_errors, 100*cur_seq))
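  # Each single_test call returns averages over one low_batch chunk, so
  # scaling the summed per-chunk averages by low_batch / batch_size yields
  # the average over the whole requested batch.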
  errors = float(low_batch) * float(errors) / batch_size
  seq = float(low_batch) * float(seq) / batch_size
  data.print_out("  %s len %d errors %.2f sequence-errors %.2f"
                 % (task, l, 100*errors, 100*seq))
  return errors, seq


def train():
  """Main training function."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(3)]
    prev_sq = 1.0

    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      ep = global_step / FLAGS.steps_per_checkpoint
      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in xrange(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
        l = l1
        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
          l = np.random.randint(max_cur_length - min_length + 1) + min_length
          l = max(l, l1)
        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
          l = np.random.randint(max_length - min_length + 1) + min_length
          l = max(l, l1)
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
        stepp = math.pow(global_step, -0.55)
        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
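        # noise_param sets the scale of annealed gradient noise: it decays
        # roughly like global_step**-0.55 and is damped by prev_sq, the
        # previous checkpoint's sequence-error, so the noise fades as
        # training converges. The t**-0.55 schedule matches the one in
        # "Adding Gradient Noise Improves Learning for Very Deep Networks"
        # (Neelakantan et al., 2015); attributing that intent is an
        # assumption, not stated in this file.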
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq = data.accuracy(inp, res, target,
                                             batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq += seq

      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq = float(acc_seq) / (step_count * batch_size)
      prev_sq = acc_seq
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
      msg = ("%s %s gn %.8f"
             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
                     (msg, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq))
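
      # Curriculum rule: when the per-sequence error drops below `curriculum`
      # (12%), lengthen the training curriculum and tighten the gate
      # relaxation pull; otherwise decay the learning rate whenever the
      # perplexity stops improving over the last three checkpoints.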
      if curriculum > acc_seq:
        prev_acc_perp.append(1000000)
        do_incr = True
        while do_incr and max_cur_length < max_length:
          sess.run(model.cur_length_incr_op)
          for t in tasks:
            if data.train_set[t]: do_incr = False
        if pull < 1:
          sess.run(model.pull_incr_op)
        else:
          data.print_out("  Averaging parameters.")
          sess.run([model.avg_op, model.lr_decay_op])
      else:
        acc_perp = data.safe_exp(acc_loss)
        if acc_perp > max(prev_acc_perp[-3:]):
          sess.run(model.lr_decay_op)
        prev_acc_perp.append(acc_perp)

      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
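
      # Training exits once every task reaches a sequence-error of at most
      # 0.1% on the long forward_max test (see should_exit below).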
      # Run evaluation.
      should_exit = True
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + 12 and l < bound:
          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
          l += 1
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if sq < 0.5:
          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                             batch_size * 4)
        if sq > 0.001: should_exit = False
      if should_exit:
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)


def animate(l, test_data, anim_size):
  """Create animation for the given data (hacky matplotlib use)."""
  xf = 12
  fps = 2
  fig = plt.figure(figsize=(16, 9), facecolor="white")
  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
  ax.set_xticks([i * 24 - 0.5 for i in xrange(4)])
  ax.set_xticklabels([])
  ax.set_yticks([i - 0.5 for i in xrange(l + 1)])
  ax.grid(which="major", axis="both", linestyle="-", color="black")
  text_fields = []
  text_size = 24 * 32 / l
  for y in xrange(l):
    text_fields.append(ax.text(
        11.25, y + 0.15, "", color="g", ha="center", va="center",
        bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
        size=text_size - (4 * 32 / l), animated=True))
  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                 interpolation="none", animated=True)
  im.set_zorder(1)

  def to_symbol(i):
    if i == 0: return ""
    if i == 11: return "+"
    if i == 12: return "*"
    return str(i - 1)
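  # to_symbol maps decoder class ids to display glyphs: id 0 is padding,
  # ids 11 and 12 render as "+" and "*", and any other id i shows the digit
  # i - 1 (an encoding assumed to match data_utils).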

  def animation_update(frame_no, test_data, xf, im, text_fields):
    """Update an animation frame."""
    steps, inpt, out_raw = test_data
    length = len(steps)
    batch = frame_no / (fps * (l + 4 * xf))
    index = int((frame_no % (fps * (l + 4 * xf))) / fps)
    # Cut output after first padding.
    out = [out_raw[i][batch] for i in xrange(len(text_fields))]
    if 0 in out:
      i = out.index(0)
      out = out[0:i] + [0 for _ in xrange(len(out) - i)]
    # Show the state after the first frames.
    if index >= 2 * xf:
      im.set_array(steps[min(length - 1, index - 2 * xf)][batch])
      for i, t in enumerate(text_fields):
        if index - 2 * xf < length:
          t.set_text("")
        else:
          t.set_text(to_symbol(out[i]))
    else:
      for i, t in enumerate(text_fields):
        t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
      if index < xf:
        im.set_array(np.zeros_like(steps[0][0]))
      else:
        im.set_array(steps[0][batch])
    return im,

  animation = anim.FuncAnimation(
      fig, animation_update, blit=True, frames=(l + 4 * xf) * anim_size * fps,
      interval=500 / fps, fargs=(test_data, xf, im, text_fields))
  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)
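  # Saving uses matplotlib's "mencoder" writer, so the mencoder binary must
  # be on the PATH; matplotlib's "ffmpeg" writer is a possible drop-in
  # substitute if mencoder is unavailable (an untested assumption).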


def evaluate():
  """Evaluate an existing model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, _, _ = initialize(sess)
    bound = data.bins[-1] + 1
    for t in tasks:
      l = min_length
      while l < max_length + 12 and l < bound:
        _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
        l += 1
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
      # Animate.
      anim_size = 2
      _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
      animate(l, test_data, anim_size)
      # More tests.
      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                         batch_size * 4)
    if sq < 0.01:  # More tests.
      if data.forward_max > 4000 and len(tasks) == 1:
        multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                   batch_size * 64, 0)


def interactive():
  """Interactively probe an existing model."""
  with tf.Session() as sess:
    model, _, _, _, _ = initialize(sess)
    sys.stdout.write("> ")
    sys.stdout.flush()
    inpt = sys.stdin.readline()
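    # Each input line is read as a string of digits; a batch of one example
    # is built with all-zero targets (preset is assumed to be a data_utils
    # hook for injecting a fixed example), and the model's argmax output ids
    # are printed back in the raw class-id encoding.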
    while inpt:
      ids = [int(c) for c in inpt.strip()]
      inpt, target = data.get_batch(len(ids), 1, False, "",
                                    preset=(ids, [0 for _ in ids]))
      _, res, _, _ = model.step(sess, inpt, target, False)
      res = [np.argmax(o, axis=1) for o in res]
      print " ".join([str(output[0]) for output in res])
      sys.stdout.write("> ")
      sys.stdout.flush()
      inpt = sys.stdin.readline()


def main(_):
  if FLAGS.mode == 0:
    train()
  elif FLAGS.mode == 1:
    evaluate()
  else:
    interactive()


if __name__ == "__main__":
  tf.app.run()