neural_gpu.py

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""The Neural GPU Model."""

import time

import google3

import tensorflow as tf

from google3.experimental.users.lukaszkaiser.neural_gpu import data_utils


def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
  """Convolutional linear map."""
  assert args
  if not isinstance(args, (list, tuple)):
    args = [args]
  with tf.variable_scope(prefix):
    k = tf.get_variable("CvK", [kw, kh, nin, nout])
    if len(args) == 1:
      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
    else:
      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
    if not do_bias: return res
    bias_term = tf.get_variable("CvB", [nout],
                                initializer=tf.constant_initializer(0.0))
    return res + bias_term + bias_start


def sigmoid_cutoff(x, cutoff):
  """Sigmoid with cutoff, e.g., 1.2 * sigmoid(x) - 0.1."""
  y = tf.sigmoid(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
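
# Worked example (illustrative): with cutoff=1.2 the map above is
# min(1, max(0, 1.2 * sigmoid(x) - 0.1)), so gate values can reach exactly
# 0.0 and 1.0 instead of only approaching them asymptotically.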


def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
  """Convolutional GRU."""
  def conv_lin(args, suffix, bias_start):
    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                       prefix + "/" + suffix)
  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  return gate * mem + (1 - gate) * candidate
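
# In GRU notation the update above is
#   r = sigmoid_cutoff(W_r * [x, m] + 1),  g = sigmoid_cutoff(W_g * [x, m] + 1),
#   c = tanh(W_c * [x, r . m]),            m' = g . m + (1 - g) . c,
# where * denotes a kw-by-kh convolution, . is elementwise multiplication, and
# the +1 offsets come from bias_start=1.0 on the reset and update gates.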


def relaxed_average(var_name_suffix, rx_step):
  """Calculate the average of relaxed variables having var_name_suffix."""
  relaxed_vars = []
  for l in xrange(rx_step):
    with tf.variable_scope("RX%d" % l, reuse=True):
      try:
        relaxed_vars.append(tf.get_variable(var_name_suffix))
      except ValueError:
        pass
  dsum = tf.add_n(relaxed_vars)
  avg = dsum / len(relaxed_vars)
  diff = [v - avg for v in relaxed_vars]
  davg = tf.add_n([d * d for d in diff])
  return avg, tf.reduce_sum(davg)


def relaxed_distance(rx_step):
  """Distance between relaxed variables and their average."""
  res, ops, rx_done = [], [], {}
  for v in tf.trainable_variables():
    if v.name[0:2] == "RX":
      rx_name = v.op.name[v.name.find("/") + 1:]
      if rx_name not in rx_done:
        avg, dist_loss = relaxed_average(rx_name, rx_step)
        res.append(dist_loss)
        rx_done[rx_name] = avg
      ops.append(v.assign(rx_done[rx_name]))
  return tf.add_n(res), tf.group(*ops)
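
# Relaxation in a nutshell: each recurrent parameter exists in rx_step copies,
# one per variable scope RX0 ... RX{rx_step-1}.  relaxed_distance returns
# (a) the total squared distance of all copies from their per-variable mean,
# which is added to the training loss scaled by the "pull" factor, and
# (b) an op that overwrites every copy with that mean, used to tie the copies
# back together (self.avg_op in NeuralGPU below).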


def make_dense(targets, noclass):
  """Convert a batch of targets to a dense 1-hot representation."""
  with tf.device("/cpu:0"):
    shape = tf.shape(targets)
    batch_size = shape[0]
    indices = targets + noclass * tf.range(0, batch_size)
    length = batch_size * noclass
    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
  return tf.reshape(dense, [-1, noclass])
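
# Illustrative example: with noclass=3 and targets=[2, 0, 1] the result is
# [[0, 0, 1], [1, 0, 0], [0, 1, 0]] (row i is the 1-hot vector of target i).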


def check_for_zero(sparse):
  """In a sparse batch of ints, return 1.0 where it's 0 and 0.0 elsewhere."""
  with tf.device("/cpu:0"):
    shape = tf.shape(sparse)
    batch_size = shape[0]
    sparse = tf.minimum(sparse, 1)
    indices = sparse + 2 * tf.range(0, batch_size)
    dense = tf.sparse_to_dense(indices, 2 * batch_size, 1.0, 0.0)
    reshaped = tf.reshape(dense, [-1, 2])
  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
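
# Illustrative example: check_for_zero([0, 5, 1, 0]) evaluates to
# [1.0, 0.0, 0.0, 1.0]; it is used below to detect padding symbols (id 0).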


class NeuralGPU(object):
  """Neural GPU Model."""

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length):
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")

    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []

    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))
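    # Note: e0 pins the embedding of symbol 0 (the padding symbol) to the zero
    # vector; the control dependency on e0 in the loop below ensures this reset
    # runs before every embedding lookup.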

    adam = tf.train.AdamOptimizer(0.01 * self.lr, epsilon=1e-5)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]

        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]
        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask
        # Add a height dimension to mask to use later for masking.
        mask = tf.reshape(mask, [-1, length, 1, 1])
        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)
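
      # At this point mask is a [batch, length, height, nmaps] tensor that is
      # 1.0 wherever the input or target symbol is non-zero and 0.0 on padding,
      # and scales[i] is 1.0 exactly at the last such position of each example;
      # it is used below to pick out that example's final-step output.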

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)

      # Computation steps.
      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
            cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
          step.append(cur * mask)
          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))
      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      self.outputs.append([tf.reshape(o, [-1, noclass])
                           for o in list(tf.split(1, length, output))])

      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add gradient noise scaled by noise_param (rebuild the list so the
        # noisy gradients are what actually reach the optimizer).
        grads = [g + tf.truncated_normal(tf.shape(g)) * self.noise_param
                 if isinstance(g, tf.Tensor) else g for g in grads]
        update = adam.apply_gradients(zip(grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)

      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))
    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None):
    """Run a step of the network."""
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    for l in xrange(length+1):
      feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:]
    return res[offset], outputs, norm, steps
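
# Sketch of intended use (not part of the original file; the hyperparameter
# values below are hypothetical and a configured data_utils module is assumed):
#
#   with tf.Session() as sess:
#     model = NeuralGPU(nmaps=24, vec_size=24, niclass=33, noclass=33,
#                       dropout=0.15, rx_step=6, max_grad_norm=1.0, cutoff=1.2,
#                       nconvs=2, kw=3, kh=3, height=4, mode=0,
#                       learning_rate=0.003, pull=0.0005, pull_incr=1.2,
#                       min_length=3)
#     sess.run(tf.initialize_all_variables())
#     # inp and target are length-lists of int32 vectors of size batch_size,
#     # where length is one of data_utils.bins; 0 is the padding symbol.
#     loss, outputs, norm, steps = model.step(sess, inp, target, True)
#     # outputs[l] holds the [batch_size, noclass] logits for position l.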