# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Neural GPU for Learning Algorithms."""
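
# Usage sketch (flags are defined below): FLAGS.mode picks the entry point
# in main() at the bottom of this file. Mode 0 trains, mode 1 evaluates an
# existing checkpoint, and any other value starts an interactive prompt,
# for example:
#   python neural_gpu_trainer.py --mode=0 --task=rev --train_dir=/tmp/
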
import math
import os
import random
import sys
import time

import matplotlib.animation as anim
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

import data_utils as data
import neural_gpu

tf.app.flags.DEFINE_float("lr", 0.001, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 1.0, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("curriculum_bound", 0.15, "Move curriculum < this.")
tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 0.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
                            "Steps per checkpoint.")
tf.app.flags.DEFINE_integer("nmaps", 128, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("niclass", 33, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("noclass", 33, "Number of classes (0 is padding).")
tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
tf.app.flags.DEFINE_integer("mode", 0,
                            "Mode: 0-train, 1-evaluate, other-interactive.")
tf.app.flags.DEFINE_bool("animate", False, "Whether to produce an animation.")
tf.app.flags.DEFINE_bool("quantize", False, "Whether to quantize variables.")
tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
tf.app.flags.DEFINE_string("ensemble", "", "Model paths for ensemble.")

FLAGS = tf.app.flags.FLAGS
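
# train() and evaluate() test on lengths up to max_length + EXTRA_EVAL,
# i.e. a few lengths beyond the end of the training curriculum.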
EXTRA_EVAL = 12


def initialize(sess):
  """Initialize data and model."""
  if FLAGS.jobid >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
  data.print_out("NN ", newline=False)
  # Set random seed.
  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)
  # Check data sizes.
  assert data.bins
  min_length = 3
  max_length = min(FLAGS.max_length, data.bins[-1])
  assert max_length + 1 > min_length
  while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
    data.bins = data.bins[:-1]
  assert data.bins[0] > FLAGS.rx_step
  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
  nclass = min(FLAGS.niclass, FLAGS.noclass)
  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000
  # Initialize data for each task.
  tasks = FLAGS.task.split("-")
  for t in tasks:
    for l in xrange(max_length + EXTRA_EVAL - 1):
      data.init_data(t, l, data_size, nclass)
    data.init_data(t, data.bins[-2], data_size, nclass)
    data.init_data(t, data.bins[-1], data_size, nclass)
    end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
    data.init_data(t, data.forward_max, end_size, nclass)
  # Print out parameters.
  curriculum = FLAGS.curriculum_bound
  msg1 = ("layers %d kw %d h %d kh %d relax %d batch %d noise %.2f task %s"
          % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
             FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
  msg2 = "data %d %s" % (FLAGS.train_data_size, msg1)
  msg3 = ("cut %.2f pull %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
          (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
           curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, msg2))
  data.print_out(msg3)
  # Create checkpoint directory if it does not exist.
  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
  if not gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    gfile.MkDir(checkpoint_dir)
  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
  model = neural_gpu.NeuralGPU(
      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
  data.print_out("Created model.")
  sess.run(tf.initialize_all_variables())
  data.print_out("Initialized variables.")
  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)
  # Check if there are ensemble models and get their checkpoints.
  ensemble = []
  ensemble_dir_list = [d for d in FLAGS.ensemble.split(",") if d]
  for ensemble_dir in ensemble_dir_list:
    ckpt = tf.train.get_checkpoint_state(ensemble_dir)
    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
      data.print_out("Found ensemble model %s" % ckpt.model_checkpoint_path)
      ensemble.append(ckpt.model_checkpoint_path)
  # Return the model and needed variables.
  return (model, min_length, max_length, checkpoint_dir, curriculum, ensemble)


def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
                offset=None, ensemble=None, get_steps=False):
  """Test model on test data of length l using the given session."""
  inpt, target = data.get_batch(l, batch_size, False, task, offset)
  _, res, _, steps = model.step(sess, inpt, target, False, get_steps=get_steps)
  errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
  seq_err = float(seq_err) / batch_size
  if total > 0:
    errors = float(errors) / total
  if print_out:
    data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                   % (task, l, 100*errors, 100*seq_err))
  # Ensemble eval.
  if ensemble:
    results = []
    for m in ensemble:
      model.saver.restore(sess, m)
      _, result, _, _ = model.step(sess, inpt, target, False)
      m_errors, m_total, m_seq_err = data.accuracy(inpt, result, target,
                                                   batch_size, nprint)
      m_seq_err = float(m_seq_err) / batch_size
      if m_total > 0:
        m_errors = float(m_errors) / m_total
      data.print_out(" %s len %d m-errors %.2f m-sequence-errors %.2f"
                     % (task, l, 100*m_errors, 100*m_seq_err))
      results.append(result)
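    # Ensemble by summing each model's per-position output scores; the
    # argmax of the summed scores then acts as the ensemble prediction.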
    ens = [sum(o) for o in zip(*results)]
    errors, total, seq_err = data.accuracy(inpt, ens, target,
                                           batch_size, nprint)
    seq_err = float(seq_err) / batch_size
    if total > 0:
      errors = float(errors) / total
    if print_out:
      data.print_out(" %s len %d ens-errors %.2f ens-sequence-errors %.2f"
                     % (task, l, 100*errors, 100*seq_err))
  return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])


def multi_test(l, model, sess, task, nprint, batch_size, offset=None,
               ensemble=None):
  """Run multiple tests at lower batch size to save memory."""
  errors, seq_err = 0.0, 0.0
  to_print = nprint
  low_batch = FLAGS.low_batch_size
  low_batch = min(low_batch, batch_size)
  for mstep in xrange(batch_size / low_batch):
    cur_offset = None if offset is None else offset + mstep * low_batch
    err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
                                 False, cur_offset, ensemble=ensemble)
    to_print = max(0, to_print - low_batch)
    errors += err
    seq_err += sq_err
    if FLAGS.mode > 0:
      cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
      cur_seq_err = float(low_batch * seq_err) / ((mstep+1) * low_batch)
      data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
                     % (task, 100*cur_errors, 100*cur_seq_err))
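  # errors and seq_err accumulated one error rate per low_batch chunk above,
  # so rescale the sums into rates over the full batch.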
  errors = float(low_batch) * float(errors) / batch_size
  seq_err = float(low_batch) * float(seq_err) / batch_size
  data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                 % (task, l, 100*errors, 100*seq_err))
  return errors, seq_err


def train():
  """Train the model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    (model, min_length, max_length, checkpoint_dir,
     curriculum, _) = initialize(sess)
    quant_op = neural_gpu.quantize_weights_op(512, 8)
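    # The curriculum starts at the shortest lengths; model.cur_length holds
    # the current maximal training length and is increased further below
    # once errors fall under the curriculum bound.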
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(3)]
    prev_seq_err = 1.0
    # Main training loop.
    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      acc_loss, acc_total, acc_errors, acc_seq_err = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in xrange(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
        # Select the length for curriculum learning.
        l = np.random.randint(max_cur_length - min_length + 1) + min_length
        # Prefer longer lengths 60% of the time.
        if np.random.randint(100) < 60:
          l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
          l = max(l, l1)
        # Mixed curriculum learning: in 25% of cases go to any larger length.
        if np.random.randint(100) < 25:
          l1 = np.random.randint(max_length - min_length + 1) + min_length
          l = max(l, l1)
        # Run a step and time it.
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
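        # Annealed gradient noise: decays as global_step^-0.55 and scales
        # with the previous sequence error, so the noise vanishes as the
        # model improves (the t^-0.55 decay matches the schedule of
        # Neelakantan et al., 2015).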
        noise_param = math.sqrt(math.pow(global_step, -0.55) *
                                prev_seq_err) * FLAGS.grad_noise_scale
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)
        # Accumulate statistics only if we did not exceed curriculum length.
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq_err = data.accuracy(inp, res, target,
                                                 batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq_err += seq_err
      # Normalize and print out accumulated statistics.
      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
      prev_seq_err = max(0.0, acc_seq_err - 0.02)  # No noise at error < 2%.
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      msg1 = "step %d step-time %.2f" % (global_step, step_time)
      msg2 = "lr %.8f pull %.3f" % (learning_rate, pull)
      msg3 = ("%s %s grad-norm %.8f"
              % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppx %.8f errors %.2f sequence-errors %.2f" %
                     (msg3, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq_err))
      # If errors are below the curriculum threshold, move curriculum forward.
      if curriculum > acc_seq_err:
        if FLAGS.quantize:
          # Quantize weights.
          data.print_out(" Quantizing parameters.")
          sess.run([quant_op])
        # Increase current length (until the next with training data).
        do_incr = True
        while do_incr and max_cur_length < max_length:
          sess.run(model.cur_length_incr_op)
          for t in tasks:
            if data.train_set[t]: do_incr = False
        # Forget last perplexities if we're not yet at the end.
        if max_cur_length < max_length:
          prev_acc_perp.append(1000000)
        # Either increase pull or, if it's large, average parameters.
        if pull < 0.1:
          sess.run(model.pull_incr_op)
        else:
          data.print_out(" Averaging parameters.")
          sess.run(model.avg_op)
          if acc_seq_err < (curriculum / 3.0):
            sess.run(model.lr_decay_op)
      # Lower learning rate if we're worse than the last 3 checkpoints.
      acc_perp = data.safe_exp(acc_loss)
      if acc_perp > max(prev_acc_perp[-3:]):
        sess.run(model.lr_decay_op)
      prev_acc_perp.append(acc_perp)
      # Save checkpoint.
      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
      # Run evaluation.
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + EXTRA_EVAL and l < bound:
          _, seq_err, _ = single_test(l, model, sess, t,
                                      FLAGS.nprint, batch_size)
          l += 1
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if seq_err < 0.05:  # Run larger test if we're good enough.
          _, seq_err = multi_test(data.forward_max, model, sess, t,
                                  FLAGS.nprint, batch_size * 4)
      if seq_err < 0.01:  # Super-large test on 1-task large-forward models.
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)


def animate(l, test_data, anim_size):
  """Create animation for the given data (hacky matplotlib use)."""
  xf = 12  # Extra frames to slow down at start and end.
  fps = 2  # Frames per step.
  # Make the figure.
  fig = plt.figure(figsize=(16, 9), facecolor="white")
  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
  ax.set_xticks([i * 24 - 0.5 for i in xrange(4)])
  ax.set_xticklabels([])
  ax.set_yticks([i - 0.5 for i in xrange(l + 1)])
  ax.grid(which="major", axis="both", linestyle="-", color="black")
  # We need text fields.
  text_fields = []
  text_size = 24 * 32 / l
  for y in xrange(l):
    text_fields.append(ax.text(
        11.25, y + 0.15, "", color="g", ha="center", va="center",
        bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
        size=text_size - (4 * 32 / l), animated=True))
  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                 interpolation="none", animated=True)
  im.set_zorder(1)

  # Main animation step.
  def animation_update(frame_no, test_data, xf, im, text_fields):
    """Update an animation frame."""
    steps, inpt, out_raw = test_data
    length = len(steps)
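    # Each of the anim_size examples owns fps * (l + 4*xf) frames: batch
    # selects the example, index the frame within that example.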
    batch = frame_no / (fps * (l + 4*xf))
    index = int((frame_no % (fps * (l + 4*xf))) / fps)
    # Cut output after first padding.
    out = [out_raw[i][batch] for i in xrange(len(text_fields))]
    if 0 in out:
      i = out.index(0)
      out = out[0:i] + [0 for _ in xrange(len(out) - i)]
    # Show the state after the first frames.
    if index >= 2*xf:
      im.set_array(steps[min(length - 1, index - 2*xf)][batch])
      for i, t in enumerate(text_fields):
        if index - 2*xf < length:
          t.set_text("")
        else:
          t.set_text(data.to_symbol(out[i]))
    else:
      for i, t in enumerate(text_fields):
        t.set_text(data.to_symbol(inpt[i][batch]) if index < xf else "")
      if index < xf:
        im.set_array(np.zeros_like(steps[0][0]))
      else:
        im.set_array(steps[0][batch])
    return im,

  # Create the animation and save to mp4.
  animation = anim.FuncAnimation(
      fig, animation_update, blit=True, frames=(l + 4*xf) * anim_size * fps,
      interval=500/fps, fargs=(test_data, xf, im, text_fields))
  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)


def evaluate():
  """Evaluate an existing model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, _, _, ensemble = initialize(sess)
    bound = data.bins[-1] + 1
    for t in tasks:
      l = min_length
      while l < max_length + EXTRA_EVAL and l < bound:
        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint,
                                    batch_size, ensemble=ensemble)
        l += 1
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
      # Animate.
      if FLAGS.animate:
        anim_size = 2
        _, _, test_data = single_test(l, model, sess, t, 0, anim_size,
                                      get_steps=True)
        animate(l, test_data, anim_size)
      # More tests.
      _, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                              batch_size * 4, ensemble=ensemble)
    if seq_err < 0.01:  # Super-test if we're very good and in large-test mode.
      if data.forward_max > 4000 and len(tasks) == 1:
        multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                   batch_size * 64, 0, ensemble=ensemble)


def interactive():
  """Interactively probe an existing model."""
  with tf.Session() as sess:
    model, _, _, _, _, _ = initialize(sess)
    sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
    sys.stdout.write("> ")
    sys.stdout.flush()
    inpt = sys.stdin.readline()
    while inpt:
      ids = [data.to_id(s) for s in inpt.strip().split()]
      inpt, target = data.get_batch(len(ids), 1, False, "",
                                    preset=(ids, [0 for _ in ids]))
      _, res, _, _ = model.step(sess, inpt, target, False)
      res = [np.argmax(o, axis=1) for o in res]
      res = [o for o in res[:len(ids)] if o > 0]
      print " " + " ".join([data.to_symbol(output[0]) for output in res])
      sys.stdout.write("> ")
      sys.stdout.flush()
      inpt = sys.stdin.readline()


def main(_):
  if FLAGS.mode == 0:
    train()
  elif FLAGS.mode == 1:
    evaluate()
  else:
    interactive()


if __name__ == "__main__":
  tf.app.run()