neural_gpu.py

  1. """The Neural GPU Model."""
  2. import time
  3. import tensorflow as tf
  4. import data_utils
def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
  """Convolutional linear map."""
  assert args
  if not isinstance(args, (list, tuple)):
    args = [args]
  with tf.variable_scope(prefix):
    k = tf.get_variable("CvK", [kw, kh, nin, nout])
    if len(args) == 1:
      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
    else:
      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
    if not do_bias: return res
    bias_term = tf.get_variable("CvB", [nout],
                                initializer=tf.constant_initializer(0.0))
    return res + bias_term + bias_start

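# Shape note (a sketch, not part of the original source): conv_linear expects
# 4-D tensors of shape [batch, width, height, nin] and, with "SAME" padding
# and stride 1, returns [batch, width, height, nout]. For example (values and
# the scope name "lin0" are purely illustrative),
# conv_linear(x, 3, 3, 24, 24, True, 0.0, "lin0") maps a [32, 41, 4, 24]
# tensor to another [32, 41, 4, 24] tensor.
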
def sigmoid_cutoff(x, cutoff):
  """Sigmoid with cutoff, e.g., 1.2 * sigmoid(x) - 0.1 clipped to [0, 1]."""
  y = tf.sigmoid(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))

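# Why the cutoff (explanatory note, added): for cutoff c > 1 the result is
# clip(c * sigmoid(x) - (c - 1) / 2, 0, 1), a sigmoid stretched around 0.5 so
# that it can actually reach the saturation values 0.0 and 1.0, letting gates
# close or open completely; c = 1.2 gives the 1.2 * sigmoid(x) - 0.1 of the
# docstring.
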
def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
  """Convolutional GRU."""
  def conv_lin(args, suffix, bias_start):
    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                       prefix + "/" + suffix)
  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  return gate * mem + (1 - gate) * candidate

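# Gate note (added): this is the standard GRU update, gate * mem +
# (1 - gate) * candidate, applied position-wise over the whole grid; the r and
# g gates use bias_start=1.0, so early in training each CGRU step leans
# towards copying its memory forward unchanged.
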
def relaxed_average(var_name_suffix, rx_step):
  """Calculate the average of relaxed variables having var_name_suffix."""
  relaxed_vars = []
  for l in xrange(rx_step):
    with tf.variable_scope("RX%d" % l, reuse=True):
      try:
        relaxed_vars.append(tf.get_variable(var_name_suffix))
      except ValueError:
        pass
  dsum = tf.add_n(relaxed_vars)
  avg = dsum / len(relaxed_vars)
  diff = [v - avg for v in relaxed_vars]
  davg = tf.add_n([d*d for d in diff])
  return avg, tf.reduce_sum(davg)

def relaxed_distance(rx_step):
  """Distance between relaxed variables and their average."""
  res, ops, rx_done = [], [], {}
  for v in tf.trainable_variables():
    if v.name[0:2] == "RX":
      rx_name = v.op.name[v.name.find("/") + 1:]
      if rx_name not in rx_done:
        avg, dist_loss = relaxed_average(rx_name, rx_step)
        res.append(dist_loss)
        rx_done[rx_name] = avg
      ops.append(v.assign(rx_done[rx_name]))
  return tf.add_n(res), tf.group(*ops)

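# Relaxation note (added): the recurrent weights are temporarily untied into
# rx_step copies living in variable scopes RX0 ... RX(rx_step - 1); the loss
# returned here pulls every copy towards the copies' common average, and the
# returned op overwrites each copy with that average (re-tying the weights).
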
def make_dense(targets, noclass):
  """Move a batch of targets to a dense 1-hot representation."""
  with tf.device("/cpu:0"):
    shape = tf.shape(targets)
    batch_size = shape[0]
    indices = targets + noclass * tf.range(0, batch_size)
    length = batch_size * noclass
    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
    return tf.reshape(dense, [-1, noclass])

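# Example (illustrative values, added): with targets = [2, 0] and noclass = 3,
# indices become [2, 3], so the dense result reshapes to
# [[0, 0, 1], [1, 0, 0]] -- one 1-hot row per target.
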
def check_for_zero(sparse):
  """For a batch of ints, return 1.0 where the entry is 0 and 0.0 elsewhere."""
  with tf.device("/cpu:0"):
    shape = tf.shape(sparse)
    batch_size = shape[0]
    sparse = tf.minimum(sparse, 1)
    indices = sparse + 2 * tf.range(0, batch_size)
    dense = tf.sparse_to_dense(indices, 2 * batch_size, 1.0, 0.0)
    reshaped = tf.reshape(dense, [-1, 2])
    return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])

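# Example (illustrative values, added): check_for_zero on the batch [0, 7, 0]
# evaluates to [1.0, 0.0, 1.0]; in this model the symbol 0 is padding, so the
# result marks padded positions.
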
class NeuralGPU(object):
  """Neural GPU Model."""

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length):
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")
    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []
    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))
    adam = tf.train.AdamOptimizer(0.01*self.lr, epsilon=1e-5)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]

        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]
        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i+1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
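        # Note (added): assuming padding forms a contiguous suffix (as
        # produced by data_utils), scales[i] is 1.0 for a batch element
        # exactly at its last non-padding position i and 0.0 elsewhere, so the
        # weighted sum over outputs below reads the step that matches each
        # element's true length.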
        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask
        # Add a height dimension to mask to use later for masking.
        mask = tf.reshape(mask, [-1, length, 1, 1])
        mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
            tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)
      # Computation steps.
      step = [tf.nn.dropout(first, 1.0 - self.do_training * dropout) * mask]
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
            cur = tf.nn.dropout(cur, 1.0 - self.do_training * dropout)
          step.append(cur * mask)
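          # Readout note (added): keep only height index 0 of the grid as this
          # step's output.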
          outputs.append(tf.slice(step[-1], [0, 0, 0, 0], [-1, -1, 1, -1]))

      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])

      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      self.outputs.append([tf.reshape(o, [-1, noclass])
                           for o in list(tf.split(1, length, output))])
      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add gradient noise (scaled by noise_param) to every dense gradient;
        # rebuilding the list makes sure the noisy gradients reach Adam.
        grads = [g + tf.truncated_normal(tf.shape(g)) * self.noise_param
                 if isinstance(g, tf.Tensor) else g for g in grads]
        update = adam.apply_gradients(zip(grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)

      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))

    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None):
    """Run a step of the network."""
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    for l in xrange(length+1):
      feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
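    # Layout note (added): res holds [update, gradient norm,] loss, then
    # `length` output tensors, then `length + 1` step tensors, in that order;
    # the slicing below assumes exactly this layout.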
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:]
    return res[offset], outputs, norm, steps
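
# Usage sketch (added; the parameter values are purely illustrative and the
# real training loop lives in the accompanying trainer script, not here):
#
#   import tensorflow as tf
#
#   with tf.Session() as sess:
#     model = NeuralGPU(nmaps=24, vec_size=24, niclass=33, noclass=33,
#                       dropout=0.1, rx_step=6, max_grad_norm=1.0, cutoff=1.2,
#                       nconvs=2, kw=3, kh=3, height=4, mode=0,
#                       learning_rate=0.001, pull=0.0005, pull_incr=1.2,
#                       min_length=3)
#     sess.run(tf.initialize_all_variables())
#     # inp and target are length-lists of int32 batches, padded with 0 to a
#     # bin length from data_utils.bins.
#     loss, outputs, norm, _ = model.step(sess, inp, target, do_backward=True,
#                                         noise_param=0.001)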