# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Neural GPU for Learning Algorithms."""

import math
import os
import random
import sys
import time

import google3

import matplotlib.animation as anim
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from google3.third_party.tensorflow.python.platform import gfile

import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu

tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.")
tf.app.flags.DEFINE_integer("nmaps", 24, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("niclass", 14, "Number of input classes (0 is padding).")
tf.app.flags.DEFINE_integer("noclass", 14, "Number of output classes (0 is padding).")
tf.app.flags.DEFINE_integer("train_data_size", 5000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 41, "Maximum length.")
tf.app.flags.DEFINE_integer("rx_step", 6, "Relax that many recursive steps.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("forward_max", 401, "Maximum forward length.")
tf.app.flags.DEFINE_integer("jobid", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train, 1-evaluate, other-interactive.")
tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")

FLAGS = tf.app.flags.FLAGS
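
# Example invocation (illustrative; flags are parsed by tf.app.flags): train
# on the default "rev" task with checkpoints stored under /tmp/:
#
#   python neural_gpu_trainer.py --task=rev --max_length=41 --nmaps=24
#
# Use --mode=1 to evaluate a saved model, and any other value of --mode for
# the interactive prompt.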


def initialize(sess):
  """Initialize data and model."""
  if FLAGS.jobid >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir, "log%d" % FLAGS.jobid)
  data.print_out("NN ", newline=False)

  # Set random seed.
  seed = FLAGS.random_seed + max(0, FLAGS.jobid)
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)

  # Check data sizes.
  data.forward_max = max(FLAGS.forward_max, data.bins[-1])
  assert data.bins
  min_length = 3
  max_length = min(FLAGS.max_length, data.bins[-1])
  assert max_length + 1 > min_length
  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
    data.bins = data.bins[:-1]
  assert data.bins[0] > FLAGS.rx_step
  nclass = min(FLAGS.niclass, FLAGS.noclass)
  data_size = FLAGS.train_data_size if FLAGS.mode == 0 else 1000

  # Initialize data for each task.
  tasks = FLAGS.task.split("-")
  for t in tasks:
    for l in xrange(max_length + 11):
      data.init_data(t, l, data_size, nclass)
    data.init_data(t, data.bins[-2], data_size, nclass)
    data.init_data(t, data.bins[-1], data_size, nclass)
    end_size = 4 * 1024 if FLAGS.mode > 0 else 1024
    data.init_data(t, data.forward_max, end_size, nclass)

  # Print out parameters.
  curriculum = 0.12
  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
  fin = "data %d %s" % (FLAGS.train_data_size, fin)
  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
  data.print_out(tag)

  # Create checkpoint directory if it does not exist.
  checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                % ("" if FLAGS.jobid < 0 else str(FLAGS.jobid)))
  if not gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    gfile.MkDir(checkpoint_dir)

  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
  model = ngpu.NeuralGPU(
      FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
      FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
      FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
      FLAGS.pull, FLAGS.pull_incr, min_length + 3)
  data.print_out("Created model.")
  sess.run(tf.initialize_all_variables())
  data.print_out("Initialized variables.")

  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)

  # Return the model and needed variables.
  return (model, min_length, max_length, checkpoint_dir, curriculum)


def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
                offset=None):
  """Test model on test data of length l using the given session."""
  inpt, target = data.get_batch(l, batch_size, False, task, offset)
  _, res, _, steps = model.step(sess, inpt, target, False)
  errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
  seq = float(seq) / batch_size
  if total > 0:
    errors = float(errors) / total
  if print_out:
    data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                   % (task, l, 100*errors, 100*seq))
  return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])


def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
  """Run multiple tests at lower batch size to save memory."""
  errors = 0.0
  seq = 0.0
  to_print = nprint
  low_batch = FLAGS.low_batch_size
  low_batch = min(low_batch, batch_size)
  for mstep in xrange(batch_size / low_batch):
    cur_offset = None if offset is None else offset + mstep * low_batch
    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
                             cur_offset)
    to_print = max(0, to_print - low_batch)
    errors += err
    seq += sq
    if FLAGS.mode > 0:
      cur_errors = float(low_batch * errors) / ((mstep + 1) * low_batch)
      cur_seq = float(low_batch * seq) / ((mstep + 1) * low_batch)
      data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
                     % (task, 100*cur_errors, 100*cur_seq))
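
  # errors and seq now hold sums of per-chunk averages over the
  # batch_size / low_batch chunks; rescaling by low_batch / batch_size
  # recovers the overall per-example means.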
  errors = float(low_batch) * float(errors) / batch_size
  seq = float(low_batch) * float(seq) / batch_size
  data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
                 % (task, l, 100*errors, 100*seq))
  return errors, seq


def train():
  """Main training function."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    (model, min_length, max_length,
     checkpoint_dir, curriculum) = initialize(sess)
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(3)]
    prev_sq = 1.0
    while True:
      global_step, pull, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.pull, model.cur_length, model.lr])
      ep = global_step / FLAGS.steps_per_checkpoint
      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
      for _ in xrange(FLAGS.steps_per_checkpoint):
        global_step += 1
        task = random.choice(tasks)
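
        # Pick a length for this step: sample twice from the current
        # curriculum range, keeping the longer draw 60% of the time, and
        # occasionally (25%) mix in a length from the full training range.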
        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
        l = l1
        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
          l = np.random.randint(max_cur_length - min_length + 1) + min_length
          l = max(l, l1)
        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
          l = np.random.randint(max_length - min_length + 1) + min_length
          l = max(l, l1)

        # Run a step and time it.
        start_time = time.time()
        inp, target = data.get_batch(l, batch_size, True, task)
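        # Anneal the gradient noise: it decays as global_step**-0.55 and
        # scales with the last sequence-error rate, so the noise fades as
        # training improves (cf. Neelakantan et al., "Adding Gradient Noise
        # Improves Learning for Very Deep Networks", 2015).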
        stepp = math.pow(global_step, -0.55)
        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
        loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
        step_time += time.time() - start_time
        acc_grad_norm += float(gnorm)

        # Accumulate statistics only for lengths within the curriculum.
        if l < max_cur_length + 1:
          step_count += 1
          acc_loss += loss
          errors, total, seq = data.accuracy(inp, res, target, batch_size, 0)
          acc_total += total
          acc_errors += errors
          acc_seq += seq
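
      # Normalize the statistics accumulated over this checkpoint period.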
      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq = float(acc_seq) / (step_count * batch_size)
      prev_sq = acc_seq
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
      msg = ("%s %s gn %.8f"
             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
                     (msg, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq))
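
      # If sequence errors beat the curriculum threshold, advance the
      # curriculum: lengthen the inputs and increase the relaxation pull
      # (or average parameters once pull reaches 1). Otherwise decay the
      # learning rate when perplexity stops improving.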
      if curriculum > acc_seq:
        prev_acc_perp.append(1000000)
        do_incr = True
        while do_incr and max_cur_length < max_length:
          sess.run(model.cur_length_incr_op)
          for t in tasks:
            if data.train_set[t]: do_incr = False
        if pull < 1:
          sess.run(model.pull_incr_op)
        else:
          data.print_out(" Averaging parameters.")
          sess.run([model.avg_op, model.lr_decay_op])
      else:
        acc_perp = data.safe_exp(acc_loss)
        if acc_perp > max(prev_acc_perp[-3:]):
          sess.run(model.lr_decay_op)
        prev_acc_perp.append(acc_perp)
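
      # Save a checkpoint.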
      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)

      # Run evaluation.
      should_exit = True
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
        while l < max_length + 12 and l < bound:
          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
          l += 1
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
        if sq < 0.5:  # Run a larger test if errors are low enough.
          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                             batch_size * 4)
        if sq > 0.001: should_exit = False
      if should_exit:
        # Run a super-large test on single-task large-forward models.
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)


def animate(l, test_data, anim_size):
  """Create animation for the given data (hacky matplotlib use)."""
  xf = 12  # Number of lead-in frames that show the input before the states.
  fps = 2
  fig = plt.figure(figsize=(16, 9), facecolor="white")
  ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
  ax.set_xticks([i * 24 - 0.5 for i in xrange(4)])
  ax.set_xticklabels([])
  ax.set_yticks([i - 0.5 for i in xrange(l + 1)])
  ax.grid(which="major", axis="both", linestyle="-", color="black")
  text_fields = []
  text_size = 24 * 32 / l
  for y in xrange(l):
    text_fields.append(ax.text(
        11.25, y + 0.15, "", color="g", ha="center", va="center",
        bbox={"facecolor": "b", "alpha": 0.01, "pad": 24 * text_size},
        size=text_size - (4 * 32 / l), animated=True))
  im = ax.imshow(np.zeros_like(test_data[0][0][0]), vmin=-1.0,
                 vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                 interpolation="none", animated=True)
  im.set_zorder(1)
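
  # Map class ids to displayable symbols: 0 is padding (shown as blank),
  # 11 and 12 render as "+" and "*", and digit classes are shifted by one.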
  def to_symbol(i):
    if i == 0: return ""
    if i == 11: return "+"
    if i == 12: return "*"
    return str(i - 1)

  def animation_update(frame_no, test_data, xf, im, text_fields):
    """Update an animation frame."""
    steps, inpt, out_raw = test_data
    length = len(steps)
    batch = frame_no / (fps * (l + 4*xf))
    index = int((frame_no % (fps * (l + 4*xf))) / fps)
    # Cut output after first padding.
    out = [out_raw[i][batch] for i in xrange(len(text_fields))]
    if 0 in out:
      i = out.index(0)
      out = out[0:i] + [0 for _ in xrange(len(out) - i)]
    # Show the state after the first frames.
    if index >= 2*xf:
      im.set_array(steps[min(length - 1, index - 2*xf)][batch])
      for i, t in enumerate(text_fields):
        if index - 2*xf < length:
          t.set_text("")
        else:
          t.set_text(to_symbol(out[i]))
    else:
      for i, t in enumerate(text_fields):
        t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
      if index < xf:
        im.set_array(np.zeros_like(steps[0][0]))
      else:
        im.set_array(steps[0][batch])
    return im,

  animation = anim.FuncAnimation(
      fig, animation_update, blit=True, frames=(l + 4*xf) * anim_size * fps,
      interval=500/fps, fargs=(test_data, xf, im, text_fields))
  animation.save("/tmp/neural_gpu.mp4", writer="mencoder", fps=4*fps, dpi=3*80)


def evaluate():
  """Evaluate an existing model."""
  batch_size = FLAGS.batch_size
  tasks = FLAGS.task.split("-")
  with tf.Session() as sess:
    model, min_length, max_length, _, _ = initialize(sess)
    bound = data.bins[-1] + 1
    for t in tasks:
      l = min_length
      while l < max_length + 12 and l < bound:
        _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
        l += 1
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
      # Animate the last tested length.
      anim_size = 2
      _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
      animate(l, test_data, anim_size)
      # Run a larger test.
      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
                         batch_size * 4)
    if sq < 0.01:  # Run a super-large test on single-task models.
      if data.forward_max > 4000 and len(tasks) == 1:
        multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                   batch_size * 64, 0)


def interactive():
  """Interactively probe an existing model."""
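  # Reads one line of digit characters (e.g. "2345") at a time, feeds it to
  # the model as symbol ids, and prints the argmax output id per position.
  # With a model trained on the default "rev" task, the output should encode
  # the input sequence reversed.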
  with tf.Session() as sess:
    model, _, _, _, _ = initialize(sess)
    sys.stdout.write("> ")
    sys.stdout.flush()
    inpt = sys.stdin.readline()
    while inpt:
      ids = [int(c) for c in inpt.strip()]
      inpt, target = data.get_batch(len(ids), 1, False, "",
                                    preset=(ids, [0 for _ in ids]))
      _, res, _, _ = model.step(sess, inpt, target, False)
      res = [np.argmax(o, axis=1) for o in res]
      print " ".join([str(output[0]) for output in res])
      sys.stdout.write("> ")
      sys.stdout.flush()
      inpt = sys.stdin.readline()


def main(_):
  if FLAGS.mode == 0:
    train()
  elif FLAGS.mode == 1:
    evaluate()
  else:
    interactive()


if __name__ == "__main__":
  tf.app.run()