# neural_gpu_trainer.py
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Neural GPU."""

import math
import os
import random
import sys
import threading
import time

import numpy as np
import tensorflow as tf

import program_utils
import data_utils as data
import neural_gpu as ngpu
import wmt_utils as wmt

tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 0.8, "Initial weights deviation.")
tf.app.flags.DEFINE_float("max_grad_norm", 4.0, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("curriculum_ppx", 9.9, "Move curriculum if ppl < X.")
tf.app.flags.DEFINE_float("curriculum_seq", 0.3, "Move curriculum if seq < X.")
tf.app.flags.DEFINE_float("dropout", 0.1, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 0.0, "Gradient noise scale.")
tf.app.flags.DEFINE_float("max_sampling_rate", 0.1, "Maximal sampling rate.")
tf.app.flags.DEFINE_float("length_norm", 0.0, "Length normalization.")
tf.app.flags.DEFINE_float("train_beam_freq", 0.0, "Beam-based training.")
tf.app.flags.DEFINE_float("train_beam_anneal", 20000, "How many steps anneal.")
tf.app.flags.DEFINE_integer("eval_beam_steps", 4, "How many beam steps eval.")
tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size.")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 100, "Steps per epoch.")
tf.app.flags.DEFINE_integer("nmaps", 64, "Number of floats in each cell.")
tf.app.flags.DEFINE_integer("vec_size", 64, "Size of word vectors.")
tf.app.flags.DEFINE_integer("train_data_size", 1000, "Training examples/len.")
tf.app.flags.DEFINE_integer("max_length", 40, "Maximum length.")
tf.app.flags.DEFINE_integer("random_seed", 125459, "Random seed.")
tf.app.flags.DEFINE_integer("nconvs", 2, "How many convolutions / 1 step.")
tf.app.flags.DEFINE_integer("kw", 3, "Kernel width.")
tf.app.flags.DEFINE_integer("kh", 3, "Kernel height.")
tf.app.flags.DEFINE_integer("height", 4, "Height.")
tf.app.flags.DEFINE_integer("mem_size", -1, "Memory size (sqrt)")
tf.app.flags.DEFINE_integer("soft_mem_size", 1024, "Softmax memory this size.")
tf.app.flags.DEFINE_integer("num_gpus", 1, "Number of GPUs to use.")
tf.app.flags.DEFINE_integer("num_replicas", 1, "Number of replicas in use.")
tf.app.flags.DEFINE_integer("beam_size", 1, "Beam size during decoding. "
                            "If 0, no decoder, the non-extended Neural GPU.")
tf.app.flags.DEFINE_integer("max_target_vocab", 0,
                            "Maximal size of target vocabulary.")
tf.app.flags.DEFINE_integer("decode_offset", 0, "Offset for decoding.")
tf.app.flags.DEFINE_integer("task", -1, "Task id when running on borg.")
tf.app.flags.DEFINE_integer("nprint", 0, "How many test examples to print out.")
tf.app.flags.DEFINE_integer("eval_bin_print", 3, "How many bins step in eval.")
tf.app.flags.DEFINE_integer("mode", 0, "Mode: 0-train other-decode.")
tf.app.flags.DEFINE_bool("atrous", False, "Whether to use atrous convs.")
tf.app.flags.DEFINE_bool("layer_norm", False, "Do layer normalization.")
tf.app.flags.DEFINE_bool("quantize", False, "Whether to quantize variables.")
tf.app.flags.DEFINE_bool("do_train", True, "If false, only update memory.")
tf.app.flags.DEFINE_bool("rnn_baseline", False, "If true build an RNN instead.")
tf.app.flags.DEFINE_bool("simple_tokenizer", False,
                         "If true, tokenize on spaces only, digits are 0.")
tf.app.flags.DEFINE_bool("normalize_digits", True,
                         "Whether to normalize digits with simple tokenizer.")
tf.app.flags.DEFINE_integer("vocab_size", 16, "Joint vocabulary size.")
tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
tf.app.flags.DEFINE_string("test_file_prefix", "", "Files to test (.en,.fr).")
tf.app.flags.DEFINE_integer("max_train_data_size", 0,
                            "Limit on the size of training data (0: no limit).")
tf.app.flags.DEFINE_string("word_vector_file_en", "",
                           "Optional file with word vectors to start training.")
tf.app.flags.DEFINE_string("word_vector_file_fr", "",
                           "Optional file with word vectors to start training.")
tf.app.flags.DEFINE_string("problem", "wmt", "What problem are we solving?")
tf.app.flags.DEFINE_integer("ps_tasks", 0, "Number of ps tasks used.")
tf.app.flags.DEFINE_string("master", "", "Name of the TensorFlow master.")

FLAGS = tf.app.flags.FLAGS

EXTRA_EVAL = 10
EVAL_LEN_INCR = 8
MAXLEN_F = 2.0


def zero_split(tok_list, append=None):
  """Split tok_list (list of ints) on 0s, append int to all parts if given."""
  res, cur, l = [], [], 0
  for tok in tok_list:
    if tok == 0:
      if append is not None:
        cur.append(append)
      res.append(cur)
      l = max(l, len(cur))
      cur = []
    else:
      cur.append(tok)
  if append is not None:
    cur.append(append)
  res.append(cur)
  l = max(l, len(cur))
  return res, l
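# Illustrative example (not part of the original source): zero_split splits a
# flat token list on 0-separators and also returns the longest part, e.g.
#   zero_split([3, 4, 0, 5], append=9) == ([[3, 4, 9], [5, 9]], 3)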


def read_data(source_path, target_path, buckets, max_size=None, print_out=True):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    buckets: the buckets to use.
    max_size: maximum number of lines to read; all others will be ignored.
      If 0 or None, data files will be read completely (no limit).
      If set to 1, no data will be returned (empty lists of the right form).
    print_out: whether to print out status or not.

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in buckets]
  counter = 0
  if max_size != 1:
    with tf.gfile.GFile(source_path, mode="r") as source_file:
      with tf.gfile.GFile(target_path, mode="r") as target_file:
        source, target = source_file.readline(), target_file.readline()
        while source and target and (not max_size or counter < max_size):
          counter += 1
          if counter % 100000 == 0 and print_out:
            print " reading data line %d" % counter
            sys.stdout.flush()
          source_ids = [int(x) for x in source.split()]
          target_ids = [int(x) for x in target.split()]
          source_ids, source_len = zero_split(source_ids)
          target_ids, target_len = zero_split(target_ids, append=wmt.EOS_ID)
          for bucket_id, size in enumerate(buckets):
            if source_len <= size and target_len <= size:
              data_set[bucket_id].append([source_ids, target_ids])
              break
          source, target = source_file.readline(), target_file.readline()
  return data_set
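# Sketch of the resulting structure (illustrative, not from the original
# source): with buckets = [8, 16], a pair whose zero-split source and target
# parts both fit within length 8 lands in data_set[0], pairs that only fit
# within 16 land in data_set[1], and longer pairs are dropped.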


global_train_set = {"wmt": []}
train_buckets_scale = {"wmt": []}


def calculate_buckets_scale(data_set, buckets, problem):
  """Calculate buckets scales for the given data set."""
  train_bucket_sizes = [len(data_set[b]) for b in xrange(len(buckets))]
  train_total_size = max(1, float(sum(train_bucket_sizes)))
  # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
  # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
  # the size of the i-th training bucket, as used later.
  if problem not in train_buckets_scale:
    train_buckets_scale[problem] = []
  train_buckets_scale[problem].append(
      [sum(train_bucket_sizes[:i + 1]) / train_total_size
       for i in xrange(len(train_bucket_sizes))])
  return train_total_size
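# Worked example (illustrative, not from the original source): bucket sizes
# [10, 30, 60] yield the cumulative scale [0.1, 0.4, 1.0], so a uniform draw
# in [0, 1] later selects bucket i with probability proportional to its size.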


def read_data_into_global(source_path, target_path, buckets,
                          max_size=None, print_out=True):
  """Read data into the global variables (can be in a separate thread)."""
  # pylint: disable=global-variable-not-assigned
  global global_train_set, train_buckets_scale
  # pylint: enable=global-variable-not-assigned
  data_set = read_data(source_path, target_path, buckets, max_size, print_out)
  global_train_set["wmt"].append(data_set)
  train_total_size = calculate_buckets_scale(data_set, buckets, "wmt")
  if print_out:
    print " Finished global data reading (%d)." % train_total_size


def initialize(sess=None):
  """Initialize data and model."""
  global MAXLEN_F
  # Create training directory if it does not exist.
  if not tf.gfile.IsDirectory(FLAGS.train_dir):
    data.print_out("Creating training directory %s." % FLAGS.train_dir)
    tf.gfile.MkDir(FLAGS.train_dir)
  decode_suffix = "beam%dln%d" % (FLAGS.beam_size,
                                  int(100 * FLAGS.length_norm))
  if FLAGS.mode == 0:
    decode_suffix = ""
  if FLAGS.task >= 0:
    data.log_filename = os.path.join(FLAGS.train_dir,
                                     "log%d%s" % (FLAGS.task, decode_suffix))
  else:
    data.log_filename = os.path.join(FLAGS.train_dir, "neural_gpu/log")

  # Set random seed.
  if FLAGS.random_seed > 0:
    seed = FLAGS.random_seed + max(0, FLAGS.task)
    tf.set_random_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

  # Check data sizes.
  assert data.bins
  max_length = min(FLAGS.max_length, data.bins[-1])
  while len(data.bins) > 1 and data.bins[-2] >= max_length + EXTRA_EVAL:
    data.bins = data.bins[:-1]
  if sess is None and FLAGS.task == 0 and FLAGS.num_replicas > 1:
    if max_length > 60:
      max_length = max_length * 1 / 2  # Save memory on chief.
  min_length = min(14, max_length - 3) if FLAGS.problem == "wmt" else 3
  for p in FLAGS.problem.split("-"):
    if p in ["progeval", "progsynth"]:
      min_length = max(26, min_length)
  assert max_length + 1 > min_length
  while len(data.bins) > 1 and data.bins[-2] >= max_length + EXTRA_EVAL:
    data.bins = data.bins[:-1]

  # Create checkpoint directory if it does not exist.
  if FLAGS.mode == 0 or FLAGS.task < 0:
    checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
                                  % ("" if FLAGS.task < 0 else str(FLAGS.task)))
  else:
    checkpoint_dir = FLAGS.train_dir
  if not tf.gfile.IsDirectory(checkpoint_dir):
    data.print_out("Creating checkpoint directory %s." % checkpoint_dir)
    tf.gfile.MkDir(checkpoint_dir)

  # Prepare data.
  if FLAGS.problem == "wmt":
    # Prepare WMT data.
    data.print_out("Preparing WMT data in %s" % FLAGS.data_dir)
    if FLAGS.simple_tokenizer:
      MAXLEN_F = 3.5
      (en_train, fr_train, en_dev, fr_dev,
       en_path, fr_path) = wmt.prepare_wmt_data(
           FLAGS.data_dir, FLAGS.vocab_size,
           tokenizer=wmt.space_tokenizer,
           normalize_digits=FLAGS.normalize_digits)
    else:
      (en_train, fr_train, en_dev, fr_dev,
       en_path, fr_path) = wmt.prepare_wmt_data(
           FLAGS.data_dir, FLAGS.vocab_size)

    # Read data into buckets and compute their sizes.
    fr_vocab, rev_fr_vocab = wmt.initialize_vocabulary(fr_path)
    data.vocab = fr_vocab
    data.rev_vocab = rev_fr_vocab
    data.print_out("Reading development and training data (limit: %d)."
                   % FLAGS.max_train_data_size)
    dev_set = {}
    dev_set["wmt"] = read_data(en_dev, fr_dev, data.bins)
    def data_read(size, print_out):
      read_data_into_global(en_train, fr_train, data.bins, size, print_out)
    data_read(50000, False)
    read_thread_small = threading.Thread(
        name="reading-data-small", target=lambda: data_read(900000, False))
    read_thread_small.start()
    read_thread_full = threading.Thread(
        name="reading-data-full",
        target=lambda: data_read(FLAGS.max_train_data_size, True))
    read_thread_full.start()
    data.print_out("Data reading set up.")
  else:
    # Prepare algorithmic data.
    en_path, fr_path = None, None
    tasks = FLAGS.problem.split("-")
    data_size = FLAGS.train_data_size
    for t in tasks:
      data.print_out("Generating data for %s." % t)
      if t in ["progeval", "progsynth"]:
        data.init_data(t, data.bins[-1], 20 * data_size, FLAGS.vocab_size)
        if len(program_utils.prog_vocab) > FLAGS.vocab_size - 2:
          raise ValueError("Increase vocab_size to %d for prog-tasks."
                           % (len(program_utils.prog_vocab) + 2))
        data.rev_vocab = program_utils.prog_vocab
        data.vocab = program_utils.prog_rev_vocab
      else:
        for l in xrange(max_length + EXTRA_EVAL - 1):
          data.init_data(t, l, data_size, FLAGS.vocab_size)
        data.init_data(t, data.bins[-2], data_size, FLAGS.vocab_size)
        data.init_data(t, data.bins[-1], data_size, FLAGS.vocab_size)
      if t not in global_train_set:
        global_train_set[t] = []
      global_train_set[t].append(data.train_set[t])
      calculate_buckets_scale(data.train_set[t], data.bins, t)
    dev_set = data.test_set

  # Grid-search parameters.
  lr = FLAGS.lr
  init_weight = FLAGS.init_weight
  max_grad_norm = FLAGS.max_grad_norm
  if sess is not None and FLAGS.task > -1:
    def job_id_factor(step):
      """If jobid / step mod 3 is 0, 1, 2: say 0, 1, -1."""
      return ((((FLAGS.task / step) % 3) + 1) % 3) - 1
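    # Illustrative values (not in the original source): with FLAGS.task = 5,
    # job_id_factor(1) sees 5 % 3 == 2 and returns -1, so lr below is halved;
    # a remainder of 0 keeps the base value and a remainder of 1 doubles it.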
    lr *= math.pow(2, job_id_factor(1))
    init_weight *= math.pow(1.5, job_id_factor(3))
    max_grad_norm *= math.pow(2, job_id_factor(9))

  # Print out parameters.
  curriculum = FLAGS.curriculum_seq
  msg1 = ("layers %d kw %d h %d kh %d batch %d noise %.2f"
          % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh,
             FLAGS.batch_size, FLAGS.grad_noise_scale))
  msg2 = ("cut %.2f lr %.3f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s"
          % (FLAGS.cutoff, lr, init_weight, curriculum, FLAGS.nmaps,
             FLAGS.dropout, max_grad_norm, msg1))
  data.print_out(msg2)

  # Create model and initialize it.
  tf.get_variable_scope().set_initializer(
      tf.orthogonal_initializer(gain=1.8 * init_weight))
  max_sampling_rate = FLAGS.max_sampling_rate if FLAGS.mode == 0 else 0.0
  o = FLAGS.vocab_size if FLAGS.max_target_vocab < 1 else FLAGS.max_target_vocab
  ngpu.CHOOSE_K = FLAGS.soft_mem_size
  do_beam_model = FLAGS.train_beam_freq > 0.0001 and FLAGS.beam_size > 1
  beam_size = FLAGS.beam_size if FLAGS.mode > 0 and not do_beam_model else 1
  beam_size = min(beam_size, FLAGS.beam_size)
  beam_model = None
  def make_ngpu(cur_beam_size, back):
    return ngpu.NeuralGPU(
        FLAGS.nmaps, FLAGS.vec_size, FLAGS.vocab_size, o,
        FLAGS.dropout, max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
        FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mem_size,
        lr / math.sqrt(FLAGS.num_replicas), min_length + 3, FLAGS.num_gpus,
        FLAGS.num_replicas, FLAGS.grad_noise_scale, max_sampling_rate,
        atrous=FLAGS.atrous, do_rnn=FLAGS.rnn_baseline,
        do_layer_norm=FLAGS.layer_norm, beam_size=cur_beam_size, backward=back)
  if sess is None:
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      model = make_ngpu(beam_size, True)
      if do_beam_model:
        tf.get_variable_scope().reuse_variables()
        beam_model = make_ngpu(FLAGS.beam_size, False)
  else:
    model = make_ngpu(beam_size, True)
    if do_beam_model:
      tf.get_variable_scope().reuse_variables()
      beam_model = make_ngpu(FLAGS.beam_size, False)

  sv = None
  if sess is None:
    # The supervisor configuration has a few overridden options.
    sv = tf.train.Supervisor(logdir=checkpoint_dir,
                             is_chief=(FLAGS.task < 1),
                             saver=model.saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=15 * 60,
                             global_step=model.global_step)
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = sv.PrepareSession(FLAGS.master, config=config)

  data.print_out("Created model. Checkpoint dir %s" % checkpoint_dir)

  # Load model from parameters if a checkpoint exists.
  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
  if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path + ".index"):
    data.print_out("Reading model parameters from %s"
                   % ckpt.model_checkpoint_path)
    model.saver.restore(sess, ckpt.model_checkpoint_path)
  elif sv is None:
    sess.run(tf.global_variables_initializer())
    data.print_out("Initialized variables (no supervisor mode).")
  elif FLAGS.task < 1 and FLAGS.mem_size > 0:
    # sess.run(model.mem_norm_op)
    data.print_out("Created new model and normalized mem (on chief).")

  # Return the model and needed variables.
  return (model, beam_model, min_length, max_length, checkpoint_dir,
          (global_train_set, dev_set, en_path, fr_path), sv, sess)


def m_step(model, beam_model, sess, batch_size, inp, target, bucket, nsteps, p):
  """Evaluation multi-step for program synthesis."""
  state, scores, hist = None, [[-11.0 for _ in xrange(batch_size)]], []
  for _ in xrange(nsteps):
    # Get the best beam (no training, just forward model).
    new_target, new_first, new_inp, new_scores = get_best_beam(
        beam_model, sess, inp, target,
        batch_size, FLAGS.beam_size, bucket, hist, p, test_mode=True)
    hist.append(new_first)
    _, _, _, state = model.step(sess, inp, new_target, False, state=state)
    inp = new_inp
    scores.append([max(scores[-1][i], new_scores[i])
                   for i in xrange(batch_size)])
  # The final step with the true target.
  loss, res, _, _ = model.step(sess, inp, target, False, state=state)
  return loss, res, new_target, scores[1:]


def single_test(bin_id, model, sess, nprint, batch_size, dev, p, print_out=True,
                offset=None, beam_model=None):
  """Test model on test data of length l using the given session."""
  if not dev[p][bin_id]:
    data.print_out(" bin %d (%d)\t%s\tppl NA errors NA seq-errors NA"
                   % (bin_id, data.bins[bin_id], p))
    return 1.0, 1.0, 0.0
  inpt, target = data.get_batch(
      bin_id, batch_size, dev[p], FLAGS.height, offset)
  if FLAGS.beam_size > 1 and beam_model:
    loss, res, new_tgt, scores = m_step(
        model, beam_model, sess, batch_size, inpt, target, bin_id,
        FLAGS.eval_beam_steps, p)
    score_avgs = [sum(s) / float(len(s)) for s in scores]
    score_maxs = [max(s) for s in scores]
    score_str = ["(%.2f, %.2f)" % (score_avgs[i], score_maxs[i])
                 for i in xrange(FLAGS.eval_beam_steps)]
    data.print_out(" == scores (avg, max): %s" % "; ".join(score_str))
    errors, total, seq_err = data.accuracy(inpt, res, target, batch_size,
                                           nprint, new_tgt, scores[-1])
  else:
    loss, res, _, _ = model.step(sess, inpt, target, False)
    errors, total, seq_err = data.accuracy(inpt, res, target, batch_size,
                                           nprint)
  seq_err = float(seq_err) / batch_size
  if total > 0:
    errors = float(errors) / total
  if print_out:
    data.print_out(" bin %d (%d)\t%s\tppl %.2f errors %.2f seq-errors %.2f"
                   % (bin_id, data.bins[bin_id], p, data.safe_exp(loss),
                      100 * errors, 100 * seq_err))
  return (errors, seq_err, loss)


def assign_vectors(word_vector_file, embedding_key, vocab_path, sess):
  """Assign the embedding_key variable from the given word vectors file."""
  # For words in the word vector file, set their embedding at start.
  if not tf.gfile.Exists(word_vector_file):
    data.print_out("Word vector file does not exist: %s" % word_vector_file)
    sys.exit(1)
  vocab, _ = wmt.initialize_vocabulary(vocab_path)
  vectors_variable = [v for v in tf.trainable_variables()
                      if embedding_key == v.name]
  if len(vectors_variable) != 1:
    data.print_out("Word vector variable not found or too many.")
    sys.exit(1)
  vectors_variable = vectors_variable[0]
  vectors = vectors_variable.eval()
  data.print_out("Pre-setting word vectors from %s" % word_vector_file)
  with tf.gfile.GFile(word_vector_file, mode="r") as f:
    # Lines have format: dog 0.045123 -0.61323 0.413667 ...
    for line in f:
      line_parts = line.split()
      # The first part is the word.
      word = line_parts[0]
      if word in vocab:
        # Remaining parts are components of the vector.
        word_vector = np.array(map(float, line_parts[1:]))
        if len(word_vector) != FLAGS.vec_size:
          data.print_out("Warn: Word '%s', Expecting vector size %d, "
                         "found %d" % (word, FLAGS.vec_size,
                                       len(word_vector)))
        else:
          vectors[vocab[word]] = word_vector
  # Assign the modified vectors to the vectors_variable in the graph.
  sess.run([vectors_variable.initializer],
           {vectors_variable.initializer.inputs[1]: vectors})


def print_vectors(embedding_key, vocab_path, word_vector_file):
  """Print vectors from the given variable."""
  _, rev_vocab = wmt.initialize_vocabulary(vocab_path)
  vectors_variable = [v for v in tf.trainable_variables()
                      if embedding_key == v.name]
  if len(vectors_variable) != 1:
    data.print_out("Word vector variable not found or too many.")
    sys.exit(1)
  vectors_variable = vectors_variable[0]
  vectors = vectors_variable.eval()
  l, s = vectors.shape[0], vectors.shape[1]
  data.print_out("Printing %d word vectors from %s to %s."
                 % (l, embedding_key, word_vector_file))
  with tf.gfile.GFile(word_vector_file, mode="w") as f:
    # Lines have format: dog 0.045123 -0.61323 0.413667 ...
    for i in xrange(l):
      f.write(rev_vocab[i])
      for j in xrange(s):
        f.write(" %.8f" % vectors[i][j])
      f.write("\n")


def get_bucket_id(train_buckets_scale_c, max_cur_length, data_set):
  """Get a random bucket id."""
  # Choose a bucket according to data distribution. Pick a random number
  # in [0, 1] and use the corresponding interval in train_buckets_scale.
  random_number_01 = np.random.random_sample()
  bucket_id = min([i for i in xrange(len(train_buckets_scale_c))
                   if train_buckets_scale_c[i] > random_number_01])
  while bucket_id > 0 and not data_set[bucket_id]:
    bucket_id -= 1
  for _ in xrange(10 if np.random.random_sample() < 0.9 else 1):
    if data.bins[bucket_id] > max_cur_length:
      random_number_01 = min(random_number_01, np.random.random_sample())
      bucket_id = min([i for i in xrange(len(train_buckets_scale_c))
                       if train_buckets_scale_c[i] > random_number_01])
      while bucket_id > 0 and not data_set[bucket_id]:
        bucket_id -= 1
  return bucket_id
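# Rough intuition (illustrative comment, not in the original source): buckets
# are sampled in proportion to their data, and if the sampled bin exceeds
# max_cur_length the draw is usually retried with a smaller random number,
# biasing training toward lengths the curriculum has already reached.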


def score_beams(beams, target, inp, history, p,
                print_out=False, test_mode=False):
  """Score beams."""
  if p == "progsynth":
    return score_beams_prog(beams, target, inp, history, print_out, test_mode)
  elif test_mode:
    return beams[0], 10.0 if str(beams[0][:len(target)]) == str(target) else 0.0
  else:
    history_s = [str(h) for h in history]
    best, best_score, tgt, eos_id = None, -1000.0, target, None
    if p == "wmt":
      eos_id = wmt.EOS_ID
    if eos_id and eos_id in target:
      tgt = target[:target.index(eos_id)]
    for beam in beams:
      if eos_id and eos_id in beam:
        beam = beam[:beam.index(eos_id)]
      l = min(len(tgt), len(beam))
      score = len([i for i in xrange(l) if tgt[i] == beam[i]]) / float(len(tgt))
      hist_score = 20.0 if str([b for b in beam if b > 0]) in history_s else 0.0
      if score < 1.0:
        score -= hist_score
      if score > best_score:
        best = beam
        best_score = score
    return best, best_score


def score_beams_prog(beams, target, inp, history, print_out=False,
                     test_mode=False):
  """Score beams for program synthesis."""
  tgt_prog = linearize(target, program_utils.prog_vocab, True, 1)
  hist_progs = [linearize(h, program_utils.prog_vocab, True, 1)
                for h in history]
  tgt_set = set(target)
  if print_out:
    print "target: ", tgt_prog
  inps, tgt_outs = [], []
  for i in xrange(3):
    ilist = [inp[i + 1, l] for l in xrange(inp.shape[1])]
    clist = [program_utils.prog_vocab[x] for x in ilist if x > 0]
    olist = clist[clist.index("]") + 1:]  # outputs
    clist = clist[1:clist.index("]")]  # inputs
    inps.append([int(x) for x in clist])
    if olist[0] == "[":  # olist may be [int] or just int
      tgt_outs.append(str([int(x) for x in olist[1:-1]]))
    else:
      if len(olist) == 1:
        tgt_outs.append(olist[0])
      else:
        print [program_utils.prog_vocab[x] for x in ilist if x > 0]
        print olist
        print tgt_prog
        print program_utils.evaluate(tgt_prog, {"a": inps[-1]})
        print "AAAAA"
        tgt_outs.append(olist[0])
  if not test_mode:
    for _ in xrange(7):
      ilen = np.random.randint(len(target) - 3) + 1
      inps.append([random.choice(range(-15, 15)) for _ in range(ilen)])
    tgt_outs.extend([program_utils.evaluate(tgt_prog, {"a": inp})
                     for inp in inps[3:]])
  best, best_prog, best_score = None, "", -1000.0
  for beam in beams:
    b_prog = linearize(beam, program_utils.prog_vocab, True, 1)
    b_set = set(beam)
    jsim = len(tgt_set & b_set) / float(len(tgt_set | b_set))
    b_outs = [program_utils.evaluate(b_prog, {"a": inp}) for inp in inps]
    errs = len([x for x in b_outs if x == "ERROR"])
    imatches = len([i for i in xrange(3) if b_outs[i] == tgt_outs[i]])
    perfect = 10.0 if imatches == 3 else 0.0
    hist_score = 20.0 if b_prog in hist_progs else 0.0
    if test_mode:
      score = perfect - errs
    else:
      matches = len([i for i in xrange(10) if b_outs[i] == tgt_outs[i]])
      score = perfect + matches + jsim - errs
    if score < 10.0:
      score -= hist_score
    # print b_prog
    # print "jsim: ", jsim, " errs: ", errs, " mtchs: ", matches, " s: ", score
    if score > best_score:
      best = beam
      best_prog = b_prog
      best_score = score
  if print_out:
    print "best score: ", best_score, " best prog: ", best_prog
  return best, best_score


def get_best_beam(beam_model, sess, inp, target, batch_size, beam_size,
                  bucket, history, p, test_mode=False):
  """Run beam_model, score beams, and return the best as target and in input."""
  _, output_logits, _, _ = beam_model.step(
      sess, inp, target, None, beam_size=FLAGS.beam_size)
  new_targets, new_firsts, scores, new_inp = [], [], [], np.copy(inp)
  for b in xrange(batch_size):
    outputs = []
    history_b = [[h[b, 0, l] for l in xrange(data.bins[bucket])]
                 for h in history]
    for beam_idx in xrange(beam_size):
      outputs.append([int(o[beam_idx * batch_size + b])
                      for o in output_logits])
    target_t = [target[b, 0, l] for l in xrange(data.bins[bucket])]
    best, best_score = score_beams(
        outputs, [t for t in target_t if t > 0], inp[b, :, :],
        [[t for t in h if t > 0] for h in history_b], p, test_mode=test_mode)
    scores.append(best_score)
    if 1 in best:  # Only until _EOS.
      best = best[:best.index(1) + 1]
    best += [0 for _ in xrange(len(target_t) - len(best))]
    new_targets.append([best])
    first, _ = score_beams(
        outputs, [t for t in target_t if t > 0], inp[b, :, :],
        [[t for t in h if t > 0] for h in history_b], p, test_mode=True)
    if 1 in first:  # Only until _EOS.
      first = first[:first.index(1) + 1]
    first += [0 for _ in xrange(len(target_t) - len(first))]
    new_inp[b, 0, :] = np.array(first, dtype=np.int32)
    new_firsts.append([first])
  # Change target if we found a great answer.
  new_target = np.array(new_targets, dtype=np.int32)
  for b in xrange(batch_size):
    if scores[b] >= 10.0:
      target[b, 0, :] = new_target[b, 0, :]
  new_first = np.array(new_firsts, dtype=np.int32)
  return new_target, new_first, new_inp, scores
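# Illustrative note (not in the original source): best_score values of 10.0 or
# more essentially only arise when score_beams reports a (near-)perfect beam,
# e.g. the exact-match 10.0 in test_mode or the perfect == 10.0 bonus in
# score_beams_prog, so only such beams overwrite the true target above.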


def train():
  """Train the model."""
  batch_size = FLAGS.batch_size * FLAGS.num_gpus
  (model, beam_model, min_length, max_length, checkpoint_dir,
   (train_set, dev_set, en_vocab_path, fr_vocab_path), sv, sess) = initialize()
  with sess.as_default():
    quant_op = model.quantize_op
    max_cur_length = min(min_length + 3, max_length)
    prev_acc_perp = [1000000 for _ in xrange(5)]
    prev_seq_err = 1.0
    is_chief = FLAGS.task < 1
    do_report = False

    # Main training loop.
    while not sv.ShouldStop():
      global_step, max_cur_length, learning_rate = sess.run(
          [model.global_step, model.cur_length, model.lr])
      acc_loss, acc_l1, acc_total, acc_errors, acc_seq_err = 0.0, 0.0, 0, 0, 0
      acc_grad_norm, step_count, step_c1, step_time = 0.0, 0, 0, 0.0

      # For words in the word vector file, set their embedding at start.
      bound1 = FLAGS.steps_per_checkpoint - 1
      if FLAGS.word_vector_file_en and global_step < bound1 and is_chief:
        assign_vectors(FLAGS.word_vector_file_en, "embedding:0",
                       en_vocab_path, sess)
        if FLAGS.max_target_vocab < 1:
          assign_vectors(FLAGS.word_vector_file_en, "target_embedding:0",
                         en_vocab_path, sess)
      if FLAGS.word_vector_file_fr and global_step < bound1 and is_chief:
        assign_vectors(FLAGS.word_vector_file_fr, "embedding:0",
                       fr_vocab_path, sess)
        if FLAGS.max_target_vocab < 1:
          assign_vectors(FLAGS.word_vector_file_fr, "target_embedding:0",
                         fr_vocab_path, sess)

      for _ in xrange(FLAGS.steps_per_checkpoint):
        step_count += 1
        step_c1 += 1
        global_step = int(model.global_step.eval())
        train_beam_anneal = global_step / float(FLAGS.train_beam_anneal)
        train_beam_freq = FLAGS.train_beam_freq * min(1.0, train_beam_anneal)
        p = random.choice(FLAGS.problem.split("-"))
        train_set = global_train_set[p][-1]
        bucket_id = get_bucket_id(train_buckets_scale[p][-1], max_cur_length,
                                  train_set)
        # Prefer longer stuff 60% of time if not wmt.
        if np.random.randint(100) < 60 and FLAGS.problem != "wmt":
          bucket1 = get_bucket_id(train_buckets_scale[p][-1], max_cur_length,
                                  train_set)
          bucket_id = max(bucket1, bucket_id)

        # Run a step and time it.
        start_time = time.time()
        inp, target = data.get_batch(bucket_id, batch_size, train_set,
                                     FLAGS.height)
        noise_param = math.sqrt(math.pow(global_step + 1, -0.55) *
                                prev_seq_err) * FLAGS.grad_noise_scale
        # In multi-step mode, we use best from beam for middle steps.
        state, new_target, scores, history = None, None, None, []
        while (FLAGS.beam_size > 1 and
               train_beam_freq > np.random.random_sample()):
          # Get the best beam (no training, just forward model).
          new_target, new_first, new_inp, scores = get_best_beam(
              beam_model, sess, inp, target,
              batch_size, FLAGS.beam_size, bucket_id, history, p)
          history.append(new_first)
          # Training step with the previous input and the best beam as target.
          _, _, _, state = model.step(sess, inp, new_target, FLAGS.do_train,
                                      noise_param, update_mem=True, state=state)
          # Change input to the new one for the next step.
          inp = new_inp
          # If all results are great, stop (todo: not to wait for all?).
          if FLAGS.nprint > 1:
            print scores
          if sum(scores) / float(len(scores)) >= 10.0:
            break
        # The final step with the true target.
        loss, res, gnorm, _ = model.step(
            sess, inp, target, FLAGS.do_train, noise_param,
            update_mem=True, state=state)
        step_time += time.time() - start_time
        acc_grad_norm += 0.0 if gnorm is None else float(gnorm)

        # Accumulate statistics.
        acc_loss += loss
        acc_l1 += loss
        errors, total, seq_err = data.accuracy(
            inp, res, target, batch_size, 0, new_target, scores)
        if FLAGS.nprint > 1:
          print "seq_err: ", seq_err
        acc_total += total
        acc_errors += errors
        acc_seq_err += seq_err

        # Report summary every 10 steps.
        if step_count + 3 > FLAGS.steps_per_checkpoint:
          do_report = True  # Don't pollute plot too early.
        if is_chief and step_count % 10 == 1 and do_report:
          cur_loss = acc_l1 / float(step_c1)
          acc_l1, step_c1 = 0.0, 0
          cur_perp = data.safe_exp(cur_loss)
          summary = tf.Summary()
          summary.value.extend(
              [tf.Summary.Value(tag="log_perplexity", simple_value=cur_loss),
               tf.Summary.Value(tag="perplexity", simple_value=cur_perp)])
          sv.SummaryComputed(sess, summary, global_step)

      # Normalize and print out accumulated statistics.
      acc_loss /= step_count
      step_time /= FLAGS.steps_per_checkpoint
      acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
      prev_seq_err = max(0.0, acc_seq_err - 0.02)  # No noise at error < 2%.
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
      t_size = float(sum([len(x) for x in train_set])) / float(1000000)
      msg = ("step %d step-time %.2f train-size %.3f lr %.6f grad-norm %.4f"
             % (global_step + 1, step_time, t_size, learning_rate,
                acc_grad_norm / FLAGS.steps_per_checkpoint))
      data.print_out("%s len %d ppl %.6f errors %.2f sequence-errors %.2f" %
                     (msg, max_cur_length, data.safe_exp(acc_loss),
                      100*acc_errors, 100*acc_seq_err))

      # If errors are below the curriculum threshold, move curriculum forward.
      is_good = FLAGS.curriculum_ppx > data.safe_exp(acc_loss)
      is_good = is_good and FLAGS.curriculum_seq > acc_seq_err
      if is_good and is_chief:
        if FLAGS.quantize:
          # Quantize weights.
          data.print_out("  Quantizing parameters.")
          sess.run([quant_op])
        # Increase current length (until the next with training data).
        sess.run(model.cur_length_incr_op)
        # Forget last perplexities if we're not yet at the end.
        if max_cur_length < max_length:
          prev_acc_perp.append(1000000)

      # Lower learning rate if we're worse than the last 5 checkpoints.
      acc_perp = data.safe_exp(acc_loss)
      if acc_perp > max(prev_acc_perp[-5:]) and is_chief:
        sess.run(model.lr_decay_op)
      prev_acc_perp.append(acc_perp)

      # Save checkpoint.
      if is_chief:
        checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
        model.saver.save(sess, checkpoint_path,
                         global_step=model.global_step)

        # Run evaluation.
        bin_bound = 4
        for p in FLAGS.problem.split("-"):
          total_loss, total_err, tl_counter = 0.0, 0.0, 0
          for bin_id in xrange(len(data.bins)):
            if bin_id < bin_bound or bin_id % FLAGS.eval_bin_print == 1:
              err, _, loss = single_test(bin_id, model, sess, FLAGS.nprint,
                                         batch_size * 4, dev_set, p,
                                         beam_model=beam_model)
              if loss > 0.0:
                total_loss += loss
                total_err += err
                tl_counter += 1
          test_loss = total_loss / max(1, tl_counter)
          test_err = total_err / max(1, tl_counter)
          test_perp = data.safe_exp(test_loss)
          summary = tf.Summary()
          summary.value.extend(
              [tf.Summary.Value(tag="test/%s/loss" % p, simple_value=test_loss),
               tf.Summary.Value(tag="test/%s/error" % p, simple_value=test_err),
               tf.Summary.Value(tag="test/%s/perplexity" % p,
                                simple_value=test_perp)])
          sv.SummaryComputed(sess, summary, global_step)


def linearize(output, rev_fr_vocab, simple_tokenizer=None, eos_id=wmt.EOS_ID):
  # If there is an EOS symbol in outputs, cut them at that point (WMT).
  if eos_id in output:
    output = output[:output.index(eos_id)]
  # Print out French sentence corresponding to outputs.
  if simple_tokenizer or FLAGS.simple_tokenizer:
    vlen = len(rev_fr_vocab)
    def vget(o):
      if o < vlen:
        return rev_fr_vocab[o]
      return "UNK"
    return " ".join([vget(o) for o in output])
  else:
    return wmt.basic_detokenizer([rev_fr_vocab[o] for o in output])
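# Illustrative example (not in the original source), assuming wmt.EOS_ID == 1
# as the "Only until _EOS" checks above suggest: with rev_fr_vocab =
# ["_PAD", "_EOS", "hello", "world"], linearize([2, 3, 1, 0], rev_fr_vocab)
# cuts at the EOS token and detokenizes ["hello", "world"].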


def evaluate():
  """Evaluate an existing model."""
  batch_size = FLAGS.batch_size * FLAGS.num_gpus
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    (model, beam_model, _, _, _,
     (_, dev_set, en_vocab_path, fr_vocab_path), _, sess) = initialize(sess)
    for p in FLAGS.problem.split("-"):
      for bin_id in xrange(len(data.bins)):
        if (FLAGS.task >= 0 and bin_id > 4) or (FLAGS.nprint == 0 and
                                                bin_id > 8 and p == "wmt"):
          break
        single_test(bin_id, model, sess, FLAGS.nprint, batch_size, dev_set, p,
                    beam_model=beam_model)
    path = FLAGS.test_file_prefix
    xid = "" if FLAGS.task < 0 else ("%.4d" % (FLAGS.task+FLAGS.decode_offset))
    en_path, fr_path = path + ".en" + xid, path + ".fr" + xid
    # Evaluate on the test files if they exist.
    if path and tf.gfile.Exists(en_path) and tf.gfile.Exists(fr_path):
      data.print_out("Translating test set %s" % en_path)
      # Read lines.
      en_lines, fr_lines = [], []
      with tf.gfile.GFile(en_path, mode="r") as f:
        for line in f:
          en_lines.append(line.strip())
      with tf.gfile.GFile(fr_path, mode="r") as f:
        for line in f:
          fr_lines.append(line.strip())
      # Tokenize and convert to ids.
      en_vocab, _ = wmt.initialize_vocabulary(en_vocab_path)
      _, rev_fr_vocab = wmt.initialize_vocabulary(fr_vocab_path)
      if FLAGS.simple_tokenizer:
        en_ids = [wmt.sentence_to_token_ids(
            l, en_vocab, tokenizer=wmt.space_tokenizer,
            normalize_digits=FLAGS.normalize_digits)
                  for l in en_lines]
      else:
        en_ids = [wmt.sentence_to_token_ids(l, en_vocab) for l in en_lines]
      # Translate.
      results = []
      for idx, token_ids in enumerate(en_ids):
        if idx % 5 == 0:
          data.print_out("Translating example %d of %d." % (idx, len(en_ids)))
        # Which bucket does it belong to?
        buckets = [b for b in xrange(len(data.bins))
                   if data.bins[b] >= len(token_ids)]
        if buckets:
          result, result_cost = [], 100000000.0
          for bucket_id in buckets:
            if data.bins[bucket_id] > MAXLEN_F * len(token_ids) + EVAL_LEN_INCR:
              break
            # Get a 1-element batch to feed the sentence to the model.
            used_batch_size = 1  # batch_size
            inp, target = data.get_batch(
                bucket_id, used_batch_size, None, FLAGS.height,
                preset=([token_ids], [[]]))
            loss, output_logits, _, _ = model.step(
                sess, inp, target, None, beam_size=FLAGS.beam_size)
            outputs = [int(o[0]) for o in output_logits]
            loss = loss[0] - (data.bins[bucket_id] * FLAGS.length_norm)
            if FLAGS.simple_tokenizer:
              cur_out = outputs
              if wmt.EOS_ID in cur_out:
                cur_out = cur_out[:cur_out.index(wmt.EOS_ID)]
              res_tags = [rev_fr_vocab[o] for o in cur_out]
              bad_words, bad_brack = wmt.parse_constraints(token_ids, res_tags)
              loss += 1000.0 * bad_words + 100.0 * bad_brack
            # print (bucket_id, loss)
            if loss < result_cost:
              result = outputs
              result_cost = loss
          final = linearize(result, rev_fr_vocab)
          results.append("%s\t%s\n" % (final, fr_lines[idx]))
          # print result_cost
          sys.stderr.write(results[-1])
          sys.stderr.flush()
        else:
          sys.stderr.write("TOOO_LONG\t%s\n" % fr_lines[idx])
          sys.stderr.flush()
      if xid:
        decode_suffix = "beam%dln%dn" % (FLAGS.beam_size,
                                         int(100 * FLAGS.length_norm))
        with tf.gfile.GFile(path + ".res" + decode_suffix + xid, mode="w") as f:
          for line in results:
            f.write(line)


def mul(l):
  """Return the product of the elements of l."""
  res = 1.0
  for s in l:
    res *= s
  return res
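# Sketch of how mul is used below (illustrative, not in the original source):
# mul([3, 3, 64, 128]) == 73728.0, e.g. the parameter count of a 3x3 conv
# kernel with 64 input and 128 output maps, so summing mul(shape) over
# tf.trainable_variables() gives the total number of trainable parameters.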


def interactive():
  """Interactively probe an existing model."""
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    # Initialize model.
    (model, _, _, _, _, (_, _, en_path, fr_path), _, _) = initialize(sess)
    # Load vocabularies.
    en_vocab, rev_en_vocab = wmt.initialize_vocabulary(en_path)
    _, rev_fr_vocab = wmt.initialize_vocabulary(fr_path)
    # Print out vectors and variables.
    if FLAGS.nprint > 0 and FLAGS.word_vector_file_en:
      print_vectors("embedding:0", en_path, FLAGS.word_vector_file_en)
    if FLAGS.nprint > 0 and FLAGS.word_vector_file_fr:
      print_vectors("target_embedding:0", fr_path, FLAGS.word_vector_file_fr)
    total = 0
    for v in tf.trainable_variables():
      shape = v.get_shape().as_list()
      total += mul(shape)
      print (v.name, shape, mul(shape))
    print total
    # Start interactive loop.
    sys.stdout.write("Input to Neural GPU Translation Model.\n")
    sys.stdout.write("> ")
    sys.stdout.flush()
    inpt = sys.stdin.readline(), ""
    while inpt:
      cures = []
      # Get token-ids for the input sentence.
      if FLAGS.simple_tokenizer:
        token_ids = wmt.sentence_to_token_ids(
            inpt, en_vocab, tokenizer=wmt.space_tokenizer,
            normalize_digits=FLAGS.normalize_digits)
      else:
        token_ids = wmt.sentence_to_token_ids(inpt, en_vocab)
      print [rev_en_vocab[t] for t in token_ids]
      # Which bucket does it belong to?
      buckets = [b for b in xrange(len(data.bins))
                 if data.bins[b] >= max(len(token_ids), len(cures))]
      if cures:
        buckets = [buckets[0]]
      if buckets:
        result, result_cost = [], 10000000.0
        for bucket_id in buckets:
          if data.bins[bucket_id] > MAXLEN_F * len(token_ids) + EVAL_LEN_INCR:
            break
          glen = 1
          for gen_idx in xrange(glen):
            # Get a 1-element batch to feed the sentence to the model.
            inp, target = data.get_batch(
                bucket_id, 1, None, FLAGS.height, preset=([token_ids], [cures]))
            loss, output_logits, _, _ = model.step(
                sess, inp, target, None, beam_size=FLAGS.beam_size,
                update_mem=False)
            # If it is a greedy decoder, outputs are argmaxes of output_logits.
            if FLAGS.beam_size > 1:
              outputs = [int(o) for o in output_logits]
            else:
              loss = loss[0] - (data.bins[bucket_id] * FLAGS.length_norm)
              outputs = [int(np.argmax(logit, axis=1))
                         for logit in output_logits]
            print [rev_fr_vocab[t] for t in outputs]
            print loss, data.bins[bucket_id]
            print linearize(outputs, rev_fr_vocab)
            cures.append(outputs[gen_idx])
            print cures
            print linearize(cures, rev_fr_vocab)
          if FLAGS.simple_tokenizer:
            cur_out = outputs
            if wmt.EOS_ID in cur_out:
              cur_out = cur_out[:cur_out.index(wmt.EOS_ID)]
            res_tags = [rev_fr_vocab[o] for o in cur_out]
            bad_words, bad_brack = wmt.parse_constraints(token_ids, res_tags)
            loss += 1000.0 * bad_words + 100.0 * bad_brack
          if loss < result_cost:
            result = outputs
            result_cost = loss
        print ("FINAL", result_cost)
        print [rev_fr_vocab[t] for t in result]
        print linearize(result, rev_fr_vocab)
      else:
        print "TOOO_LONG"
      sys.stdout.write("> ")
      sys.stdout.flush()
      inpt = sys.stdin.readline(), ""


def main(_):
  if FLAGS.mode == 0:
    train()
  elif FLAGS.mode == 1:
    evaluate()
  else:
    interactive()


if __name__ == "__main__":
  tf.app.run()