neural_gpu.py

  1. """The Neural GPU Model."""
  2. import time
  3. import google3
  4. import tensorflow as tf
  5. from google3.experimental.users.lukaszkaiser.neural_gpu import data_utils
  6. def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
  7. """Convolutional linear map."""
  8. assert args
  9. if not isinstance(args, (list, tuple)):
  10. args = [args]
  11. with tf.variable_scope(prefix):
  12. k = tf.get_variable("CvK", [kw, kh, nin, nout])
  13. if len(args) == 1:
  14. res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
  15. else:
  16. res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
  17. if not do_bias: return res
  18. bias_term = tf.get_variable("CvB", [nout],
  19. initializer=tf.constant_initializer(0.0))
  20. return res + bias_term + bias_start
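
# Note: conv_linear builds its kernel and bias under tf.variable_scope(prefix),
# so each distinct prefix (e.g. "cgru_0/r" as used by conv_gru below) gets its
# own "CvK" / "CvB" variables, shared on later calls when the enclosing scope
# is reused.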


def sigmoid_cutoff(x, cutoff):
  """Sigmoid with cutoff, e.g., 1.2 * sigmoid(x) - 0.1 for cutoff = 1.2."""
  y = tf.sigmoid(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
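
# For example, with cutoff = 1.2 this computes
# min(1.0, max(0.0, 1.2 * sigmoid(x) - 0.1)): a "hard" sigmoid that actually
# reaches 0.0 and 1.0; for cutoff < 1.01 it is the plain sigmoid.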


def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
  """Convolutional GRU."""
  def conv_lin(args, suffix, bias_start):
    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                       prefix + "/" + suffix)
  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  return gate * mem + (1 - gate) * candidate
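
# Written out, the gating above is: r = sigmoid_cutoff(W_r * [inpts, mem]),
# c = tanh(W_c * [inpts, r * mem]), g = sigmoid_cutoff(W_g * [inpts, mem]),
# new_mem = g * mem + (1 - g) * c, where each W_* is a kw x kh convolution
# produced by conv_linear.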


def relaxed_average(var_name_suffix, rx_step):
  """Calculate the average of relaxed variables having var_name_suffix."""
  relaxed_vars = []
  for l in xrange(rx_step):
    with tf.variable_scope("RX%d" % l, reuse=True):
      try:
        relaxed_vars.append(tf.get_variable(var_name_suffix))
      except ValueError:
        pass
  dsum = tf.add_n(relaxed_vars)
  avg = dsum / len(relaxed_vars)
  diff = [v - avg for v in relaxed_vars]
  davg = tf.add_n([d * d for d in diff])
  return avg, tf.reduce_sum(davg)


def relaxed_distance(rx_step):
  """Distance between relaxed variables and their average."""
  res, ops, rx_done = [], [], {}
  for v in tf.trainable_variables():
    if v.name[0:2] == "RX":
      rx_name = v.op.name[v.name.find("/") + 1:]
      if rx_name not in rx_done:
        avg, dist_loss = relaxed_average(rx_name, rx_step)
        res.append(dist_loss)
        rx_done[rx_name] = avg
      ops.append(v.assign(rx_done[rx_name]))
  return tf.add_n(res), tf.group(*ops)
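
# relaxed_distance returns (1) the summed squared distance of every "RX*"
# parameter copy from the average over its rx_step copies, which is added to
# the training loss scaled by the pull factor below, and (2) an op that
# overwrites each copy with that average, which can be run to re-tie the
# copies.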


def make_dense(targets, noclass):
  """Move a batch of targets to a dense 1-hot representation."""
  with tf.device("/cpu:0"):
    shape = tf.shape(targets)
    batch_size = shape[0]
    indices = targets + noclass * tf.range(0, batch_size)
    length = batch_size * noclass
    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
  return tf.reshape(dense, [-1, noclass])
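
# For example, targets [2, 0] with noclass = 3 give indices [2, 3] into a flat
# vector of length 6, i.e. the dense rows [[0, 0, 1], [1, 0, 0]].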


def check_for_zero(sparse):
  """In a sparse batch of ints, return 1.0 where the entry is 0, 0.0 elsewhere."""
  with tf.device("/cpu:0"):
    shape = tf.shape(sparse)
    batch_size = shape[0]
    sparse = tf.minimum(sparse, 1)
    indices = sparse + 2 * tf.range(0, batch_size)
    dense = tf.sparse_to_dense(indices, 2 * batch_size, 1.0, 0.0)
    reshaped = tf.reshape(dense, [-1, 2])
  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
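
# For example, an int batch [3, 0, 7] yields [0.0, 1.0, 0.0]: 1.0 exactly at
# the positions holding 0 (the padding symbol).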


class NeuralGPU(object):
  """Neural GPU Model."""

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length):
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")

    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []

    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))

    adam = tf.train.AdamOptimizer(0.01 * self.lr, epsilon=1e-5)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]

        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]

        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
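        # Assuming each sequence is a run of non-padding symbols followed by
        # padding, mask[i] is 1.0 up to the last real position and 0.0 after,
        # so scales[i] is 1.0 only at that last position; it is used below to
        # pick the step output matching each example's length.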

        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask

        # Add a height dimension to mask to use later for masking.
        mask = tf.reshape(mask, [-1, length, 1, 1])
        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)
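
      # The recurrent state below is a [batch, length, height, nmaps] tensor
      # (the "mental image" the CGRUs act on): the embedded input fills the
      # first of the height rows and the remaining rows start as zeros.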
      # Computation steps.
      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
            cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
          step.append(cur * mask)
          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))

      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
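
      # Note on sharing: step it uses variable scope "RX%d" % (it % rx_step),
      # so there are only rx_step copies of the CGRU parameters no matter how
      # long the bin is; relaxed_distance measures how far those copies drift
      # apart and the pull term in the loss below draws them back together.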

      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])

      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      self.outputs.append([tf.reshape(o, [-1, noclass])
                           for o in list(tf.split(1, length, output))])

      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add gradient noise scaled by the noise_param feed before the update.
        noisy_grads = []
        for grad in grads:
          if isinstance(grad, tf.Tensor):
            grad += tf.truncated_normal(tf.shape(grad)) * self.noise_param
          noisy_grads.append(grad)
        update = adam.apply_gradients(zip(noisy_grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)

      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))

    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None):
    """Run a step of the network."""
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    for l in xrange(length + 1):
      feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:]
    return res[offset], outputs, norm, steps
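

# A minimal usage sketch (not part of the original file; the hyperparameter
# values, session setup, and inp/target batches below are illustrative
# assumptions, not values prescribed by this module):
#
#   model = NeuralGPU(nmaps=24, vec_size=24, niclass=33, noclass=33,
#                     dropout=0.1, rx_step=6, max_grad_norm=1.0, cutoff=1.2,
#                     nconvs=2, kw=3, kh=3, height=4, mode=0,
#                     learning_rate=0.001, pull=0.0005, pull_incr=1.2,
#                     min_length=3)
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     # inp and target are length-lists of int batches from data_utils.
#     loss, outputs, norm, steps = model.step(sess, inp, target,
#                                             do_backward=True, noise_param=0.0)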