neural_gpu.py

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """The Neural GPU Model."""
  16. import time
  17. import tensorflow as tf
  18. import data_utils


def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
  """Convolutional linear map."""
  assert args
  if not isinstance(args, (list, tuple)):
    args = [args]
  with tf.variable_scope(prefix):
    k = tf.get_variable("CvK", [kw, kh, nin, nout])
    if len(args) == 1:
      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
    else:
      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
    if not do_bias: return res
    bias_term = tf.get_variable("CvB", [nout],
                                initializer=tf.constant_initializer(0.0))
    return res + bias_term + bias_start


def sigmoid_cutoff(x, cutoff):
  """Sigmoid with cutoff, e.g., 1.2 * sigmoid(x) - 0.1."""
  y = tf.sigmoid(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))
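
# Note: for cutoff=1.2 the line above evaluates
# min(1.0, max(0.0, 1.2 * sigmoid(x) - 0.1)), a "hard" sigmoid whose output
# can reach exactly 0.0 and 1.0, so the gates below can fully open or close.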


def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
  """Convolutional GRU."""
  def conv_lin(args, suffix, bias_start):
    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                       prefix + "/" + suffix)
  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  return gate * mem + (1 - gate) * candidate
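
# In equations, with U_* denoting kw x kh convolutions and s() the cutoff
# sigmoid above:
#   r    = s(U_r * [inpts; mem])          (reset gate)
#   c    = tanh(U_c * [inpts; r . mem])   (candidate state)
#   g    = s(U_g * [inpts; mem])          (update gate)
#   mem' = g . mem + (1 - g) . c
# where "." is elementwise multiplication and [;] is channel concatenation.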


def relaxed_average(var_name_suffix, rx_step):
  """Calculate the average of relaxed variables having var_name_suffix."""
  relaxed_vars = []
  for l in xrange(rx_step):
    with tf.variable_scope("RX%d" % l, reuse=True):
      try:
        relaxed_vars.append(tf.get_variable(var_name_suffix))
      except ValueError:
        pass
  dsum = tf.add_n(relaxed_vars)
  avg = dsum / len(relaxed_vars)
  diff = [v - avg for v in relaxed_vars]
  davg = tf.add_n([d*d for d in diff])
  return avg, tf.reduce_sum(davg)


def relaxed_distance(rx_step):
  """Distance between relaxed variables and their average."""
  res, ops, rx_done = [], [], {}
  for v in tf.trainable_variables():
    if v.name[0:2] == "RX":
      rx_name = v.op.name[v.name.find("/") + 1:]
      if rx_name not in rx_done:
        avg, dist_loss = relaxed_average(rx_name, rx_step)
        res.append(dist_loss)
        rx_done[rx_name] = avg
      ops.append(v.assign(rx_done[rx_name]))
  return tf.add_n(res), tf.group(*ops)
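
# The recurrent weights are instantiated rx_step times, under variable scopes
# RX0 .. RX{rx_step-1}. relaxed_distance returns (a) the summed squared
# distance of each copy from the copies' mean, which __init__ adds to the loss
# scaled by self.pull, and (b) an op (kept as self.avg_op) that assigns the
# mean back to every copy, re-tying the relaxed parameters together.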


def make_dense(targets, noclass):
  """Move a batch of targets to a dense 1-hot representation."""
  with tf.device("/cpu:0"):
    shape = tf.shape(targets)
    batch_size = shape[0]
    indices = targets + noclass * tf.range(0, batch_size)
    length = tf.expand_dims(batch_size * noclass, 0)
    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
  return tf.reshape(dense, [-1, noclass])


def check_for_zero(sparse):
  """In a sparse batch of ints, make 1.0 if it's 0 and 0.0 else."""
  with tf.device("/cpu:0"):
    shape = tf.shape(sparse)
    batch_size = shape[0]
    sparse = tf.minimum(sparse, 1)
    indices = sparse + 2 * tf.range(0, batch_size)
    dense = tf.sparse_to_dense(indices, tf.expand_dims(2 * batch_size, 0),
                               1.0, 0.0)
    reshaped = tf.reshape(dense, [-1, 2])
  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
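
# check_for_zero is used below to detect padding: symbol 0 marks empty space,
# and positions where both the input and the target are 0 are masked out of
# the computation and the loss.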


class NeuralGPU(object):
  """Neural GPU Model."""

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length):
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")

    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []

    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
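    # The embedding is kept on the CPU. Symbol 0 is the padding symbol; e0
    # below zeroes its row, and the embedding lookups later run under a
    # control dependency on e0 so they always see the zeroed row.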
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))

    adam = tf.train.AdamOptimizer(self.lr, epsilon=1e-4)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]
        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]
        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
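        # scales[i] is 1.0 only where position i is the last non-padding
        # position of an example and 0.0 elsewhere; it is used below to pick
        # the output produced after exactly that many processing steps.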
        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask
        # Add a height dimension to mask to use later for masking.
        mask = tf.reshape(mask, [-1, length, 1, 1])
        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)

      # Computation steps.
      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
            cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
          step.append(cur * mask)
          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))
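
      # Keep all intermediate states (reshaped to batch x length x
      # height * nmaps) so that step() below can return them for inspection.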
      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      self.outputs.append([tf.reshape(o, [-1, noclass])
                           for o in list(tf.split(1, length, output))])

      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add gradient noise scaled by the noise_param feed (0.0 when unused).
        for i, grad in enumerate(grads):
          if isinstance(grad, tf.Tensor):
            grads[i] += tf.truncated_normal(tf.shape(grad)) * self.noise_param
        update = adam.apply_gradients(zip(grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)
      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))
    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None):
    """Run a step of the network."""
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    for l in xrange(length+1):
      feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:]
    return res[offset], outputs, norm, steps
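

# Illustrative usage sketch (not part of the original file; the constructor
# arguments below are placeholder values, not necessarily the trainer's
# defaults, and the inputs must be padded to one of the data_utils.bins
# lengths):
#
#   model = NeuralGPU(nmaps=24, vec_size=24, niclass=33, noclass=33,
#                     dropout=0.1, rx_step=6, max_grad_norm=1.0, cutoff=1.2,
#                     nconvs=2, kw=3, kh=3, height=4, mode=0,
#                     learning_rate=0.001, pull=0.0005, pull_incr=1.2,
#                     min_length=3)
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     # inp and target are length-long lists; element l holds the l-th symbol
#     # of every example in the batch as an int array of shape [batch_size].
#     loss, outputs, norm, steps = model.step(sess, inp, target,
#                                             do_backward=True)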