neural_gpu.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. # Copyright 2015 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """The Neural GPU Model."""
  16. import time
  17. import numpy as np
  18. import tensorflow as tf
  19. from tensorflow.python.framework import function
  20. import data_utils as data
# When True, the CGRU layer stacks in enc_step/dec_substep are built inside
# jit_scope() (only while variables are being reused). Gives more speed but
# experimental for now.
do_jit = False
# Context manager used to mark ops for XLA JIT compilation when do_jit is set.
jit_scope = tf.contrib.compiler.jit.experimental_jit_scope
  23. def conv_linear(args, kw, kh, nin, nout, rate, do_bias, bias_start, prefix):
  24. """Convolutional linear map."""
  25. if not isinstance(args, (list, tuple)):
  26. args = [args]
  27. with tf.variable_scope(prefix):
  28. with tf.device("/cpu:0"):
  29. k = tf.get_variable("CvK", [kw, kh, nin, nout])
  30. if len(args) == 1:
  31. arg = args[0]
  32. else:
  33. arg = tf.concat(args, 3)
  34. res = tf.nn.convolution(arg, k, dilation_rate=(rate, 1), padding="SAME")
  35. if not do_bias: return res
  36. with tf.device("/cpu:0"):
  37. bias_term = tf.get_variable(
  38. "CvB", [nout], initializer=tf.constant_initializer(bias_start))
  39. bias_term = tf.reshape(bias_term, [1, 1, 1, nout])
  40. return res + bias_term
  41. def sigmoid_cutoff(x, cutoff):
  42. """Sigmoid with cutoff, e.g., 1.2sigmoid(x) - 0.1."""
  43. y = tf.sigmoid(x)
  44. if cutoff < 1.01: return y
  45. d = (cutoff - 1.0) / 2.0
  46. return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d), name="cutoff_min")
  47. @function.Defun(tf.float32, noinline=True)
  48. def sigmoid_cutoff_12(x):
  49. """Sigmoid with cutoff 1.2, specialized for speed and memory use."""
  50. y = tf.sigmoid(x)
  51. return tf.minimum(1.0, tf.maximum(0.0, 1.2 * y - 0.1), name="cutoff_min_12")
  52. @function.Defun(tf.float32, noinline=True)
  53. def sigmoid_hard(x):
  54. """Hard sigmoid."""
  55. return tf.minimum(1.0, tf.maximum(0.0, 0.25 * x + 0.5))
  56. def place_at14(decided, selected, it):
  57. """Place selected at it-th coordinate of decided, dim=1 of 4."""
  58. slice1 = decided[:, :it, :, :]
  59. slice2 = decided[:, it + 1:, :, :]
  60. return tf.concat([slice1, selected, slice2], 1)
  61. def place_at13(decided, selected, it):
  62. """Place selected at it-th coordinate of decided, dim=1 of 3."""
  63. slice1 = decided[:, :it, :]
  64. slice2 = decided[:, it + 1:, :]
  65. return tf.concat([slice1, selected, slice2], 1)
  66. def tanh_cutoff(x, cutoff):
  67. """Tanh with cutoff, e.g., 1.1tanh(x) cut to [-1. 1]."""
  68. y = tf.tanh(x)
  69. if cutoff < 1.01: return y
  70. d = (cutoff - 1.0) / 2.0
  71. return tf.minimum(1.0, tf.maximum(-1.0, (1.0 + d) * y))
  72. @function.Defun(tf.float32, noinline=True)
  73. def tanh_hard(x):
  74. """Hard tanh."""
  75. return tf.minimum(1.0, tf.maximum(0.0, x))
  76. def layer_norm(x, nmaps, prefix, epsilon=1e-5):
  77. """Layer normalize the 4D tensor x, averaging over the last dimension."""
  78. with tf.variable_scope(prefix):
  79. scale = tf.get_variable("layer_norm_scale", [nmaps],
  80. initializer=tf.ones_initializer())
  81. bias = tf.get_variable("layer_norm_bias", [nmaps],
  82. initializer=tf.zeros_initializer())
  83. mean, variance = tf.nn.moments(x, [3], keep_dims=True)
  84. norm_x = (x - mean) / tf.sqrt(variance + epsilon)
  85. return norm_x * scale + bias
  86. def conv_gru(inpts, mem, kw, kh, nmaps, rate, cutoff, prefix, do_layer_norm,
  87. args_len=None):
  88. """Convolutional GRU."""
  89. def conv_lin(args, suffix, bias_start):
  90. total_args_len = args_len or len(args) * nmaps
  91. res = conv_linear(args, kw, kh, total_args_len, nmaps, rate, True,
  92. bias_start, prefix + "/" + suffix)
  93. if do_layer_norm:
  94. return layer_norm(res, nmaps, prefix + "/" + suffix)
  95. else:
  96. return res
  97. if cutoff == 1.2:
  98. reset = sigmoid_cutoff_12(conv_lin(inpts + [mem], "r", 1.0))
  99. gate = sigmoid_cutoff_12(conv_lin(inpts + [mem], "g", 1.0))
  100. elif cutoff > 10:
  101. reset = sigmoid_hard(conv_lin(inpts + [mem], "r", 1.0))
  102. gate = sigmoid_hard(conv_lin(inpts + [mem], "g", 1.0))
  103. else:
  104. reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  105. gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  106. if cutoff > 10:
  107. candidate = tanh_hard(conv_lin(inpts + [reset * mem], "c", 0.0))
  108. else:
  109. # candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
  110. candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  111. return gate * mem + (1 - gate) * candidate
# Module-level constant. It is not referenced by any code visible in this file;
# presumably a top-k size for a filled-in memory_call (TODO: confirm).
CHOOSE_K = 256
  113. def memory_call(q, l, nmaps, mem_size, vocab_size, num_gpus, update_mem):
  114. raise ValueError("Fill for experiments with additional memory structures.")
  115. def memory_run(step, nmaps, mem_size, batch_size, vocab_size,
  116. global_step, do_training, update_mem, decay_factor, num_gpus,
  117. target_emb_weights, output_w, gpu_targets_tn, it):
  118. """Run memory."""
  119. q = step[:, 0, it, :]
  120. mlabels = gpu_targets_tn[:, it, 0]
  121. res, mask, mem_loss = memory_call(
  122. q, mlabels, nmaps, mem_size, vocab_size, num_gpus, update_mem)
  123. res = tf.gather(target_emb_weights, res) * tf.expand_dims(mask[:, 0], 1)
  124. # Mix gold and original in the first steps, 20% later.
  125. gold = tf.nn.dropout(tf.gather(target_emb_weights, mlabels), 0.7)
  126. use_gold = 1.0 - tf.cast(global_step, tf.float32) / (1000. * decay_factor)
  127. use_gold = tf.maximum(use_gold, 0.2) * do_training
  128. mem = tf.cond(tf.less(tf.random_uniform([]), use_gold),
  129. lambda: use_gold * gold + (1.0 - use_gold) * res,
  130. lambda: res)
  131. mem = tf.reshape(mem, [-1, 1, 1, nmaps])
  132. return mem, mem_loss, update_mem
@tf.RegisterGradient("CustomIdG")
def _custom_id_grad(_, grads):
  """Identity gradient: pass the incoming gradient through unchanged.

  Registered under the name "CustomIdG". quantize() maps the "Floor" op to it
  via gradient_override_map, so tf.floor acts as identity in backprop.
  """
  return grads
  136. def quantize(t, quant_scale, max_value=1.0):
  137. """Quantize a tensor t with each element in [-max_value, max_value]."""
  138. t = tf.minimum(max_value, tf.maximum(t, -max_value))
  139. big = quant_scale * (t + max_value) + 0.5
  140. with tf.get_default_graph().gradient_override_map({"Floor": "CustomIdG"}):
  141. res = (tf.floor(big) / quant_scale) - max_value
  142. return res
  143. def quantize_weights_op(quant_scale, max_value):
  144. ops = [v.assign(quantize(v, quant_scale, float(max_value)))
  145. for v in tf.trainable_variables()]
  146. return tf.group(*ops)
  147. def autoenc_quantize(x, nbits, nmaps, do_training, layers=1):
  148. """Autoencoder into nbits vectors of bits, using noise and sigmoids."""
  149. enc_x = tf.reshape(x, [-1, nmaps])
  150. for i in xrange(layers - 1):
  151. enc_x = tf.layers.dense(enc_x, nmaps, name="autoenc_%d" % i)
  152. enc_x = tf.layers.dense(enc_x, nbits, name="autoenc_%d" % (layers - 1))
  153. noise = tf.truncated_normal(tf.shape(enc_x), stddev=2.0)
  154. dec_x = sigmoid_cutoff_12(enc_x + noise * do_training)
  155. dec_x = tf.reshape(dec_x, [-1, nbits])
  156. for i in xrange(layers):
  157. dec_x = tf.layers.dense(dec_x, nmaps, name="autodec_%d" % i)
  158. return tf.reshape(dec_x, tf.shape(x))
  159. def make_dense(targets, noclass, low_param):
  160. """Move a batch of targets to a dense 1-hot representation."""
  161. low = low_param / float(noclass - 1)
  162. high = 1.0 - low * (noclass - 1)
  163. targets = tf.cast(targets, tf.int64)
  164. return tf.one_hot(targets, depth=noclass, on_value=high, off_value=low)
  165. def reorder_beam(beam_size, batch_size, beam_val, output, is_first,
  166. tensors_to_reorder):
  167. """Reorder to minimize beam costs."""
  168. # beam_val is [batch_size x beam_size]; let b = batch_size * beam_size
  169. # decided is len x b x a x b
  170. # output is b x out_size; step is b x len x a x b;
  171. outputs = tf.split(tf.nn.log_softmax(output), beam_size, 0)
  172. all_beam_vals, all_beam_idx = [], []
  173. beam_range = 1 if is_first else beam_size
  174. for i in xrange(beam_range):
  175. top_out, top_out_idx = tf.nn.top_k(outputs[i], k=beam_size)
  176. cur_beam_val = beam_val[:, i]
  177. top_out = tf.Print(top_out, [top_out, top_out_idx, beam_val, i,
  178. cur_beam_val], "GREPO", summarize=8)
  179. all_beam_vals.append(top_out + tf.expand_dims(cur_beam_val, 1))
  180. all_beam_idx.append(top_out_idx)
  181. all_beam_idx = tf.reshape(tf.transpose(tf.concat(all_beam_idx, 1), [1, 0]),
  182. [-1])
  183. top_beam, top_beam_idx = tf.nn.top_k(tf.concat(all_beam_vals, 1), k=beam_size)
  184. top_beam_idx = tf.Print(top_beam_idx, [top_beam, top_beam_idx],
  185. "GREP", summarize=8)
  186. reordered = [[] for _ in xrange(len(tensors_to_reorder) + 1)]
  187. top_out_idx = []
  188. for i in xrange(beam_size):
  189. which_idx = top_beam_idx[:, i] * batch_size + tf.range(batch_size)
  190. top_out_idx.append(tf.gather(all_beam_idx, which_idx))
  191. which_beam = top_beam_idx[:, i] / beam_size # [batch]
  192. which_beam = which_beam * batch_size + tf.range(batch_size)
  193. reordered[0].append(tf.gather(output, which_beam))
  194. for i, t in enumerate(tensors_to_reorder):
  195. reordered[i + 1].append(tf.gather(t, which_beam))
  196. new_tensors = [tf.concat(t, 0) for t in reordered]
  197. top_out_idx = tf.concat(top_out_idx, 0)
  198. return (top_beam, new_tensors[0], top_out_idx, new_tensors[1:])
  199. class NeuralGPU(object):
  200. """Neural GPU Model."""
  201. def __init__(self, nmaps, vec_size, niclass, noclass, dropout,
  202. max_grad_norm, cutoff, nconvs, kw, kh, height, mem_size,
  203. learning_rate, min_length, num_gpus, num_replicas,
  204. grad_noise_scale, sampling_rate, act_noise=0.0, do_rnn=False,
  205. atrous=False, beam_size=1, backward=True, do_layer_norm=False,
  206. autoenc_decay=1.0):
  207. # Feeds for parameters and ops to update them.
  208. self.nmaps = nmaps
  209. if backward:
  210. self.global_step = tf.Variable(0, trainable=False, name="global_step")
  211. self.cur_length = tf.Variable(min_length, trainable=False)
  212. self.cur_length_incr_op = self.cur_length.assign_add(1)
  213. self.lr = tf.Variable(learning_rate, trainable=False)
  214. self.lr_decay_op = self.lr.assign(self.lr * 0.995)
  215. self.do_training = tf.placeholder(tf.float32, name="do_training")
  216. self.update_mem = tf.placeholder(tf.int32, name="update_mem")
  217. self.noise_param = tf.placeholder(tf.float32, name="noise_param")
  218. # Feeds for inputs, targets, outputs, losses, etc.
  219. self.input = tf.placeholder(tf.int32, name="inp")
  220. self.target = tf.placeholder(tf.int32, name="tgt")
  221. self.prev_step = tf.placeholder(tf.float32, name="prev_step")
  222. gpu_input = tf.split(self.input, num_gpus, 0)
  223. gpu_target = tf.split(self.target, num_gpus, 0)
  224. gpu_prev_step = tf.split(self.prev_step, num_gpus, 0)
  225. batch_size = tf.shape(gpu_input[0])[0]
  226. if backward:
  227. adam_lr = 0.005 * self.lr
  228. adam = tf.train.AdamOptimizer(adam_lr, epsilon=1e-3)
  229. def adam_update(grads):
  230. return adam.apply_gradients(zip(grads, tf.trainable_variables()),
  231. global_step=self.global_step,
  232. name="adam_update")
  233. # When switching from Adam to SGD we perform reverse-decay.
  234. if backward:
  235. global_step_float = tf.cast(self.global_step, tf.float32)
  236. sampling_decay_exponent = global_step_float / 100000.0
  237. sampling_decay = tf.maximum(0.05, tf.pow(0.5, sampling_decay_exponent))
  238. self.sampling = sampling_rate * 0.05 / sampling_decay
  239. else:
  240. self.sampling = tf.constant(0.0)
  241. # Cache variables on cpu if needed.
  242. if num_replicas > 1 or num_gpus > 1:
  243. with tf.device("/cpu:0"):
  244. caching_const = tf.constant(0)
  245. tf.get_variable_scope().set_caching_device(caching_const.op.device)
  246. # partitioner = tf.variable_axis_size_partitioner(1024*256*4)
  247. # tf.get_variable_scope().set_partitioner(partitioner)
  248. def gpu_avg(l):
  249. if l[0] is None:
  250. for elem in l:
  251. assert elem is None
  252. return 0.0
  253. if len(l) < 2:
  254. return l[0]
  255. return sum(l) / float(num_gpus)
  256. self.length_tensor = tf.placeholder(tf.int32, name="length")
  257. with tf.device("/cpu:0"):
  258. emb_weights = tf.get_variable(
  259. "embedding", [niclass, vec_size],
  260. initializer=tf.random_uniform_initializer(-1.7, 1.7))
  261. if beam_size > 0:
  262. target_emb_weights = tf.get_variable(
  263. "target_embedding", [noclass, nmaps],
  264. initializer=tf.random_uniform_initializer(-1.7, 1.7))
  265. e0 = tf.scatter_update(emb_weights,
  266. tf.constant(0, dtype=tf.int32, shape=[1]),
  267. tf.zeros([1, vec_size]))
  268. output_w = tf.get_variable("output_w", [nmaps, noclass], tf.float32)
  269. def conv_rate(layer):
  270. if atrous:
  271. return 2**layer
  272. return 1
  273. # pylint: disable=cell-var-from-loop
  274. def enc_step(step):
  275. """Encoder step."""
  276. if autoenc_decay < 1.0:
  277. quant_step = autoenc_quantize(step, 16, nmaps, self.do_training)
  278. if backward:
  279. exp_glob = tf.train.exponential_decay(1.0, self.global_step - 10000,
  280. 1000, autoenc_decay)
  281. dec_factor = 1.0 - exp_glob # * self.do_training
  282. dec_factor = tf.cond(tf.less(self.global_step, 10500),
  283. lambda: tf.constant(0.05), lambda: dec_factor)
  284. else:
  285. dec_factor = 1.0
  286. cur = tf.cond(tf.less(tf.random_uniform([]), dec_factor),
  287. lambda: quant_step, lambda: step)
  288. else:
  289. cur = step
  290. if dropout > 0.0001:
  291. cur = tf.nn.dropout(cur, keep_prob)
  292. if act_noise > 0.00001:
  293. cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
  294. # Do nconvs-many CGRU steps.
  295. if do_jit and tf.get_variable_scope().reuse:
  296. with jit_scope():
  297. for layer in xrange(nconvs):
  298. cur = conv_gru([], cur, kw, kh, nmaps, conv_rate(layer),
  299. cutoff, "ecgru_%d" % layer, do_layer_norm)
  300. else:
  301. for layer in xrange(nconvs):
  302. cur = conv_gru([], cur, kw, kh, nmaps, conv_rate(layer),
  303. cutoff, "ecgru_%d" % layer, do_layer_norm)
  304. return cur
  305. zero_tgt = tf.zeros([batch_size, nmaps, 1])
  306. zero_tgt.set_shape([None, nmaps, 1])
  307. def dec_substep(step, decided):
  308. """Decoder sub-step."""
  309. cur = step
  310. if dropout > 0.0001:
  311. cur = tf.nn.dropout(cur, keep_prob)
  312. if act_noise > 0.00001:
  313. cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
  314. # Do nconvs-many CGRU steps.
  315. if do_jit and tf.get_variable_scope().reuse:
  316. with jit_scope():
  317. for layer in xrange(nconvs):
  318. cur = conv_gru([decided], cur, kw, kh, nmaps, conv_rate(layer),
  319. cutoff, "dcgru_%d" % layer, do_layer_norm)
  320. else:
  321. for layer in xrange(nconvs):
  322. cur = conv_gru([decided], cur, kw, kh, nmaps, conv_rate(layer),
  323. cutoff, "dcgru_%d" % layer, do_layer_norm)
  324. return cur
  325. # pylint: enable=cell-var-from-loop
  326. def dec_step(step, it, it_int, decided, output_ta, tgts,
  327. mloss, nupd_in, out_idx, beam_cost):
  328. """Decoder step."""
  329. nupd, mem_loss = 0, 0.0
  330. if mem_size > 0:
  331. it_incr = tf.minimum(it+1, length - 1)
  332. mem, mem_loss, nupd = memory_run(
  333. step, nmaps, mem_size, batch_size, noclass, self.global_step,
  334. self.do_training, self.update_mem, 10, num_gpus,
  335. target_emb_weights, output_w, gpu_targets_tn, it_incr)
  336. step = dec_substep(step, decided)
  337. output_l = tf.expand_dims(tf.expand_dims(step[:, it, 0, :], 1), 1)
  338. # Calculate argmax output.
  339. output = tf.reshape(output_l, [-1, nmaps])
  340. # pylint: disable=cell-var-from-loop
  341. output = tf.matmul(output, output_w)
  342. if beam_size > 1:
  343. beam_cost, output, out, reordered = reorder_beam(
  344. beam_size, batch_size, beam_cost, output, it_int == 0,
  345. [output_l, out_idx, step, decided])
  346. [output_l, out_idx, step, decided] = reordered
  347. else:
  348. # Scheduled sampling.
  349. out = tf.multinomial(tf.stop_gradient(output), 1)
  350. out = tf.to_int32(tf.squeeze(out, [1]))
  351. out_write = output_ta.write(it, output_l[:batch_size, :, :, :])
  352. output = tf.gather(target_emb_weights, out)
  353. output = tf.reshape(output, [-1, 1, nmaps])
  354. output = tf.concat([output] * height, 1)
  355. tgt = tgts[it, :, :, :]
  356. selected = tf.cond(tf.less(tf.random_uniform([]), self.sampling),
  357. lambda: output, lambda: tgt)
  358. # pylint: enable=cell-var-from-loop
  359. dec_write = place_at14(decided, tf.expand_dims(selected, 1), it)
  360. out_idx = place_at13(
  361. out_idx, tf.reshape(out, [beam_size * batch_size, 1, 1]), it)
  362. if mem_size > 0:
  363. mem = tf.concat([mem] * height, 2)
  364. dec_write = place_at14(dec_write, mem, it_incr)
  365. return (step, dec_write, out_write, mloss + mem_loss, nupd_in + nupd,
  366. out_idx, beam_cost)
  367. # Main model construction.
  368. gpu_outputs = []
  369. gpu_losses = []
  370. gpu_grad_norms = []
  371. grads_list = []
  372. gpu_out_idx = []
  373. self.after_enc_step = []
  374. for gpu in xrange(num_gpus): # Multi-GPU towers, average gradients later.
  375. length = self.length_tensor
  376. length_float = tf.cast(length, tf.float32)
  377. if gpu > 0:
  378. tf.get_variable_scope().reuse_variables()
  379. gpu_outputs.append([])
  380. gpu_losses.append([])
  381. gpu_grad_norms.append([])
  382. with tf.name_scope("gpu%d" % gpu), tf.device("/gpu:%d" % gpu):
  383. # Main graph creation loop.
  384. data.print_out("Creating model.")
  385. start_time = time.time()
  386. # Embed inputs and calculate mask.
  387. with tf.device("/cpu:0"):
  388. tgt_shape = tf.shape(tf.squeeze(gpu_target[gpu], [1]))
  389. weights = tf.where(tf.squeeze(gpu_target[gpu], [1]) > 0,
  390. tf.ones(tgt_shape), tf.zeros(tgt_shape))
  391. # Embed inputs and targets.
  392. with tf.control_dependencies([e0]):
  393. start = tf.gather(emb_weights, gpu_input[gpu]) # b x h x l x nmaps
  394. gpu_targets_tn = gpu_target[gpu] # b x 1 x len
  395. if beam_size > 0:
  396. embedded_targets_tn = tf.gather(target_emb_weights,
  397. gpu_targets_tn)
  398. embedded_targets_tn = tf.transpose(
  399. embedded_targets_tn, [2, 0, 1, 3]) # len x b x 1 x nmaps
  400. embedded_targets_tn = tf.concat([embedded_targets_tn] * height, 2)
  401. # First image comes from start by applying convolution and adding 0s.
  402. start = tf.transpose(start, [0, 2, 1, 3]) # Now b x len x h x vec_s
  403. first = conv_linear(start, 1, 1, vec_size, nmaps, 1, True, 0.0, "input")
  404. first = layer_norm(first, nmaps, "input")
  405. # Computation steps.
  406. keep_prob = dropout * 3.0 / tf.sqrt(length_float)
  407. keep_prob = 1.0 - self.do_training * keep_prob
  408. act_noise_scale = act_noise * self.do_training
  409. # Start with a convolutional gate merging previous step.
  410. step = conv_gru([gpu_prev_step[gpu]], first,
  411. kw, kh, nmaps, 1, cutoff, "first", do_layer_norm)
  412. # This is just for running a baseline RNN seq2seq model.
  413. if do_rnn:
  414. self.after_enc_step.append(step) # Not meaningful here, but needed.
  415. lstm_cell = tf.contrib.rnn.BasicLSTMCell(height * nmaps)
  416. cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * nconvs)
  417. with tf.variable_scope("encoder"):
  418. encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
  419. cell, tf.reshape(step, [batch_size, length, height * nmaps]),
  420. dtype=tf.float32, time_major=False)
  421. # Attention.
  422. attn = tf.layers.dense(
  423. encoder_outputs, height * nmaps, name="attn1")
  424. # pylint: disable=cell-var-from-loop
  425. @function.Defun(noinline=True)
  426. def attention_query(query, attn_v):
  427. vecs = tf.tanh(attn + tf.expand_dims(query, 1))
  428. mask = tf.reduce_sum(vecs * tf.reshape(attn_v, [1, 1, -1]), 2)
  429. mask = tf.nn.softmax(mask)
  430. return tf.reduce_sum(encoder_outputs * tf.expand_dims(mask, 2), 1)
  431. with tf.variable_scope("decoder"):
  432. def decoder_loop_fn((state, prev_cell_out, _), (cell_inp, cur_tgt)):
  433. """Decoder loop function."""
  434. attn_q = tf.layers.dense(prev_cell_out, height * nmaps,
  435. name="attn_query")
  436. attn_res = attention_query(attn_q, tf.get_variable(
  437. "attn_v", [height * nmaps],
  438. initializer=tf.random_uniform_initializer(-0.1, 0.1)))
  439. concatenated = tf.reshape(tf.concat([cell_inp, attn_res], 1),
  440. [batch_size, 2 * height * nmaps])
  441. cell_inp = tf.layers.dense(
  442. concatenated, height * nmaps, name="attn_merge")
  443. output, new_state = cell(cell_inp, state)
  444. mem_loss = 0.0
  445. if mem_size > 0:
  446. res, mask, mem_loss = memory_call(
  447. output, cur_tgt, height * nmaps, mem_size, noclass,
  448. num_gpus, self.update_mem)
  449. res = tf.gather(target_emb_weights, res)
  450. res *= tf.expand_dims(mask[:, 0], 1)
  451. output = tf.layers.dense(
  452. tf.concat([output, res], 1), height * nmaps, name="rnnmem")
  453. return new_state, output, mem_loss
  454. # pylint: enable=cell-var-from-loop
  455. gpu_targets = tf.squeeze(gpu_target[gpu], [1]) # b x len
  456. gpu_tgt_trans = tf.transpose(gpu_targets, [1, 0])
  457. dec_zero = tf.zeros([batch_size, 1], dtype=tf.int32)
  458. dec_inp = tf.concat([dec_zero, gpu_targets], 1)
  459. dec_inp = dec_inp[:, :length]
  460. embedded_dec_inp = tf.gather(target_emb_weights, dec_inp)
  461. embedded_dec_inp_proj = tf.layers.dense(
  462. embedded_dec_inp, height * nmaps, name="dec_proj")
  463. embedded_dec_inp_proj = tf.transpose(embedded_dec_inp_proj,
  464. [1, 0, 2])
  465. init_vals = (encoder_state,
  466. tf.zeros([batch_size, height * nmaps]), 0.0)
  467. _, dec_outputs, mem_losses = tf.scan(
  468. decoder_loop_fn, (embedded_dec_inp_proj, gpu_tgt_trans),
  469. initializer=init_vals)
  470. mem_loss = tf.reduce_mean(mem_losses)
  471. outputs = tf.layers.dense(dec_outputs, nmaps, name="out_proj")
  472. # Final convolution to get logits, list outputs.
  473. outputs = tf.matmul(tf.reshape(outputs, [-1, nmaps]), output_w)
  474. outputs = tf.reshape(outputs, [length, batch_size, noclass])
  475. gpu_out_idx.append(tf.argmax(outputs, 2))
  476. else: # Here we go with the Neural GPU.
  477. # Encoder.
  478. enc_length = length
  479. step = enc_step(step) # First step hard-coded.
  480. # pylint: disable=cell-var-from-loop
  481. i = tf.constant(1)
  482. c = lambda i, _s: tf.less(i, enc_length)
  483. def enc_step_lambda(i, step):
  484. with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  485. new_step = enc_step(step)
  486. return (i + 1, new_step)
  487. _, step = tf.while_loop(
  488. c, enc_step_lambda, [i, step],
  489. parallel_iterations=1, swap_memory=True)
  490. # pylint: enable=cell-var-from-loop
  491. self.after_enc_step.append(step)
  492. # Decoder.
  493. if beam_size > 0:
  494. output_ta = tf.TensorArray(
  495. dtype=tf.float32, size=length, dynamic_size=False,
  496. infer_shape=False, name="outputs")
  497. out_idx = tf.zeros([beam_size * batch_size, length, 1],
  498. dtype=tf.int32)
  499. decided_t = tf.zeros([beam_size * batch_size, length,
  500. height, vec_size])
  501. # Prepare for beam search.
  502. tgts = tf.concat([embedded_targets_tn] * beam_size, 1)
  503. beam_cost = tf.zeros([batch_size, beam_size])
  504. step = tf.concat([step] * beam_size, 0)
  505. # First step hard-coded.
  506. step, decided_t, output_ta, mem_loss, nupd, oi, bc = dec_step(
  507. step, 0, 0, decided_t, output_ta, tgts, 0.0, 0, out_idx,
  508. beam_cost)
  509. tf.get_variable_scope().reuse_variables()
  510. # pylint: disable=cell-var-from-loop
  511. def step_lambda(i, step, dec_t, out_ta, ml, nu, oi, bc):
  512. with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  513. s, d, t, nml, nu, oi, bc = dec_step(
  514. step, i, 1, dec_t, out_ta, tgts, ml, nu, oi, bc)
  515. return (i + 1, s, d, t, nml, nu, oi, bc)
  516. i = tf.constant(1)
  517. c = lambda i, _s, _d, _o, _ml, _nu, _oi, _bc: tf.less(i, length)
  518. _, step, _, output_ta, mem_loss, nupd, out_idx, _ = tf.while_loop(
  519. c, step_lambda,
  520. [i, step, decided_t, output_ta, mem_loss, nupd, oi, bc],
  521. parallel_iterations=1, swap_memory=True)
  522. # pylint: enable=cell-var-from-loop
  523. gpu_out_idx.append(tf.squeeze(out_idx, [2]))
  524. outputs = output_ta.stack()
  525. outputs = tf.squeeze(outputs, [2, 3]) # Now l x b x nmaps
  526. else:
  527. # If beam_size is 0 or less, we don't have a decoder.
  528. mem_loss = 0.0
  529. outputs = tf.transpose(step[:, :, 1, :], [1, 0, 2])
  530. gpu_out_idx.append(tf.argmax(outputs, 2))
  531. # Final convolution to get logits, list outputs.
  532. outputs = tf.matmul(tf.reshape(outputs, [-1, nmaps]), output_w)
  533. outputs = tf.reshape(outputs, [length, batch_size, noclass])
  534. gpu_outputs[gpu] = tf.nn.softmax(outputs)
  535. # Calculate cross-entropy loss and normalize it.
  536. targets_soft = make_dense(tf.squeeze(gpu_target[gpu], [1]),
  537. noclass, 0.1)
  538. targets_soft = tf.reshape(targets_soft, [-1, noclass])
  539. targets_hard = make_dense(tf.squeeze(gpu_target[gpu], [1]),
  540. noclass, 0.0)
  541. targets_hard = tf.reshape(targets_hard, [-1, noclass])
  542. output = tf.transpose(outputs, [1, 0, 2])
  543. xent_soft = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
  544. logits=tf.reshape(output, [-1, noclass]), labels=targets_soft),
  545. [batch_size, length])
  546. xent_hard = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
  547. logits=tf.reshape(output, [-1, noclass]), labels=targets_hard),
  548. [batch_size, length])
  549. low, high = 0.1 / float(noclass - 1), 0.9
  550. const = high * tf.log(high) + float(noclass - 1) * low * tf.log(low)
  551. weight_sum = tf.reduce_sum(weights) + 1e-20
  552. true_perp = tf.reduce_sum(xent_hard * weights) / weight_sum
  553. soft_loss = tf.reduce_sum(xent_soft * weights) / weight_sum
  554. perp_loss = soft_loss + const
  555. # Final loss: cross-entropy + shared parameter relaxation part + extra.
  556. mem_loss = 0.5 * tf.reduce_mean(mem_loss) / length_float
  557. total_loss = perp_loss + mem_loss
  558. gpu_losses[gpu].append(true_perp)
  559. # Gradients.
  560. if backward:
  561. data.print_out("Creating backward pass for the model.")
  562. grads = tf.gradients(
  563. total_loss, tf.trainable_variables(),
  564. colocate_gradients_with_ops=True)
  565. for g_i, g in enumerate(grads):
  566. if isinstance(g, tf.IndexedSlices):
  567. grads[g_i] = tf.convert_to_tensor(g)
  568. grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
  569. gpu_grad_norms[gpu].append(norm)
  570. for g in grads:
  571. if grad_noise_scale > 0.001:
  572. g += tf.truncated_normal(tf.shape(g)) * self.noise_param
  573. grads_list.append(grads)
  574. else:
  575. gpu_grad_norms[gpu].append(0.0)
  576. data.print_out("Created model for gpu %d in %.2f s."
  577. % (gpu, time.time() - start_time))
  578. self.updates = []
  579. self.after_enc_step = tf.concat(self.after_enc_step, 0) # Concat GPUs.
  580. if backward:
  581. tf.get_variable_scope()._reuse = False
  582. tf.get_variable_scope().set_caching_device(None)
  583. grads = [gpu_avg([grads_list[g][i] for g in xrange(num_gpus)])
  584. for i in xrange(len(grads_list[0]))]
  585. update = adam_update(grads)
  586. self.updates.append(update)
  587. else:
  588. self.updates.append(tf.no_op())
  589. self.losses = [gpu_avg([gpu_losses[g][i] for g in xrange(num_gpus)])
  590. for i in xrange(len(gpu_losses[0]))]
  591. self.out_idx = tf.concat(gpu_out_idx, 0)
  592. self.grad_norms = [gpu_avg([gpu_grad_norms[g][i] for g in xrange(num_gpus)])
  593. for i in xrange(len(gpu_grad_norms[0]))]
  594. self.outputs = [tf.concat([gpu_outputs[g] for g in xrange(num_gpus)], 1)]
  595. self.quantize_op = quantize_weights_op(512, 8)
  596. if backward:
  597. self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
  598. def step(self, sess, inp, target, do_backward_in, noise_param=None,
  599. beam_size=2, eos_id=2, eos_cost=0.0, update_mem=None, state=None):
  600. """Run a step of the network."""
  601. batch_size, height, length = inp.shape[0], inp.shape[1], inp.shape[2]
  602. do_backward = do_backward_in
  603. train_mode = True
  604. if do_backward_in is None:
  605. do_backward = False
  606. train_mode = False
  607. if update_mem is None:
  608. update_mem = do_backward
  609. feed_in = {}
  610. # print " feeding sequences of length %d" % length
  611. if state is None:
  612. state = np.zeros([batch_size, length, height, self.nmaps])
  613. feed_in[self.prev_step.name] = state
  614. feed_in[self.length_tensor.name] = length
  615. feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
  616. feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
  617. feed_in[self.update_mem.name] = 1 if update_mem else 0
  618. if do_backward_in is False:
  619. feed_in[self.sampling.name] = 0.0
  620. index = 0 # We're dynamic now.
  621. feed_out = []
  622. if do_backward:
  623. feed_out.append(self.updates[index])
  624. feed_out.append(self.grad_norms[index])
  625. if train_mode:
  626. feed_out.append(self.losses[index])
  627. feed_in[self.input.name] = inp
  628. feed_in[self.target.name] = target
  629. feed_out.append(self.outputs[index])
  630. if train_mode:
  631. # Make a full-sequence training step with one call to session.run.
  632. res = sess.run([self.after_enc_step] + feed_out, feed_in)
  633. after_enc_state, res = res[0], res[1:]
  634. else:
  635. # Make a full-sequence decoding step with one call to session.run.
  636. feed_in[self.sampling.name] = 1.1 # Sample every time.
  637. res = sess.run([self.after_enc_step, self.out_idx] + feed_out, feed_in)
  638. after_enc_state, out_idx = res[0], res[1]
  639. res = [res[2][l] for l in xrange(length)]
  640. outputs = [out_idx[:, i] for i in xrange(length)]
  641. cost = [0.0 for _ in xrange(beam_size * batch_size)]
  642. seen_eos = [0 for _ in xrange(beam_size * batch_size)]
  643. for idx, logit in enumerate(res):
  644. best = outputs[idx]
  645. for b in xrange(batch_size):
  646. if seen_eos[b] > 1:
  647. cost[b] -= eos_cost
  648. else:
  649. cost[b] += np.log(logit[b][best[b]])
  650. if best[b] in [eos_id]:
  651. seen_eos[b] += 1
  652. res = [[-c for c in cost]] + outputs
  653. # Collect and output results.
  654. offset = 0
  655. norm = None
  656. if do_backward:
  657. offset = 2
  658. norm = res[1]
  659. if train_mode:
  660. outputs = res[offset + 1]
  661. outputs = [outputs[l] for l in xrange(length)]
  662. return res[offset], outputs, norm, after_enc_state