neural_gpu.py

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The Neural GPU Model."""

import time

import tensorflow as tf

import data_utils


def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
  """Convolutional linear map."""
  assert args is not None
  if not isinstance(args, (list, tuple)):
    args = [args]
  with tf.variable_scope(prefix):
    k = tf.get_variable("CvK", [kw, kh, nin, nout])
    if len(args) == 1:
      res = tf.nn.conv2d(args[0], k, [1, 1, 1, 1], "SAME")
    else:
      res = tf.nn.conv2d(tf.concat(3, args), k, [1, 1, 1, 1], "SAME")
    if not do_bias: return res
    bias_term = tf.get_variable("CvB", [nout],
                                initializer=tf.constant_initializer(0.0))
    return res + bias_term + bias_start
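
# Note: conv_linear applies a single kw x kh convolution to the depth-wise
# concatenation of `args` (axis 3 in this TF version) and, when do_bias is
# set, adds both a learned bias and the constant offset bias_start.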


def sigmoid_cutoff(x, cutoff):
  """Sigmoid with cutoff, e.g., 1.2 * sigmoid(x) - 0.1."""
  y = tf.sigmoid(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(0.0, cutoff * y - d))


def tanh_cutoff(x, cutoff):
  """Tanh with cutoff, e.g., 1.1 * tanh(x) clipped to [-1, 1]."""
  y = tf.tanh(x)
  if cutoff < 1.01: return y
  d = (cutoff - 1.0) / 2.0
  return tf.minimum(1.0, tf.maximum(-1.0, (1.0 + d) * y))
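
# Worked example for the cutoff nonlinearities above: with cutoff = 1.2,
# sigmoid_cutoff computes min(1, max(0, 1.2 * sigmoid(x) - 0.1)), so gates
# with sigmoid(x) <= 1/12 saturate to exactly 0 and gates with
# sigmoid(x) >= 11/12 saturate to exactly 1.  tanh_cutoff analogously scales
# tanh(x) up slightly and clips the result back to [-1, 1].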


def conv_gru(inpts, mem, kw, kh, nmaps, cutoff, prefix):
  """Convolutional GRU."""
  def conv_lin(args, suffix, bias_start):
    return conv_linear(args, kw, kh, len(args) * nmaps, nmaps, True, bias_start,
                       prefix + "/" + suffix)
  reset = sigmoid_cutoff(conv_lin(inpts + [mem], "r", 1.0), cutoff)
  # candidate = tanh_cutoff(conv_lin(inpts + [reset * mem], "c", 0.0), cutoff)
  candidate = tf.tanh(conv_lin(inpts + [reset * mem], "c", 0.0))
  gate = sigmoid_cutoff(conv_lin(inpts + [mem], "g", 1.0), cutoff)
  return gate * mem + (1 - gate) * candidate
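
# The CGRU above mirrors the standard GRU equations, with convolutions in
# place of matrix multiplications (conv_lin is a kw x kh convolution over the
# depth-concatenated arguments):
#   reset     = sigmoid_cutoff(conv_lin([inputs, mem]))
#   candidate = tanh(conv_lin([inputs, reset * mem]))
#   gate      = sigmoid_cutoff(conv_lin([inputs, mem]))
#   new_mem   = gate * mem + (1 - gate) * candidate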


@tf.RegisterGradient("CustomIdG")
def _custom_id_grad(_, grads):
  return grads


def quantize(t, quant_scale, max_value=1.0):
  """Quantize a tensor t with each element in [-max_value, max_value]."""
  t = tf.minimum(max_value, tf.maximum(t, -max_value))
  big = quant_scale * (t + max_value) + 0.5
  with tf.get_default_graph().gradient_override_map({"Floor": "CustomIdG"}):
    res = (tf.floor(big) / quant_scale) - max_value
  return res
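
# The gradient_override_map in quantize routes Floor through the CustomIdG
# identity gradient registered above, so rounding only affects the forward
# pass and gradients flow through to the unquantized values (a
# straight-through-style estimator).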


def quantize_weights_op(quant_scale, max_value):
  """Quantize all trainable variables; returns the grouped assign op."""
  ops = [v.assign(quantize(v, quant_scale, float(max_value)))
         for v in tf.trainable_variables()]
  return tf.group(*ops)


def relaxed_average(var_name_suffix, rx_step):
  """Calculate the average of relaxed variables having var_name_suffix."""
  relaxed_vars = []
  for l in xrange(rx_step):
    with tf.variable_scope("RX%d" % l, reuse=True):
      try:
        relaxed_vars.append(tf.get_variable(var_name_suffix))
      except ValueError:
        pass
  dsum = tf.add_n(relaxed_vars)
  avg = dsum / len(relaxed_vars)
  diff = [v - avg for v in relaxed_vars]
  davg = tf.add_n([d * d for d in diff])
  return avg, tf.reduce_sum(davg)


def relaxed_distance(rx_step):
  """Distance between relaxed variables and their average."""
  res, ops, rx_done = [], [], {}
  for v in tf.trainable_variables():
    if v.name[0:2] == "RX":
      rx_name = v.op.name[v.name.find("/") + 1:]
      if rx_name not in rx_done:
        avg, dist_loss = relaxed_average(rx_name, rx_step)
        res.append(dist_loss)
        rx_done[rx_name] = avg
      ops.append(v.assign(rx_done[rx_name]))
  return tf.add_n(res), tf.group(*ops)
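
# Parameter relaxation: each trainable variable created under an "RX0" ...
# "RX<rx_step-1>" scope is a separate copy of a shared parameter.
# relaxed_distance returns (a) the summed squared distance of every copy from
# its per-parameter average, which the model adds to the loss scaled by
# `pull`, and (b) an op (self.avg_op below) that overwrites each copy with
# that average.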


def make_dense(targets, noclass):
  """Move a batch of targets to a dense 1-hot representation."""
  with tf.device("/cpu:0"):
    shape = tf.shape(targets)
    batch_size = shape[0]
    indices = targets + noclass * tf.range(0, batch_size)
    length = tf.expand_dims(batch_size * noclass, 0)
    dense = tf.sparse_to_dense(indices, length, 1.0, 0.0)
  return tf.reshape(dense, [-1, noclass])
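
# make_dense builds the 1-hot encoding with a flat-index trick: batch element
# i with class c gets flat index c + noclass * i, so sparse_to_dense fills a
# single vector of length batch_size * noclass that is then reshaped to
# [batch_size, noclass].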


def check_for_zero(sparse):
  """In a batch of ints, return 1.0 where the value is 0 and 0.0 elsewhere."""
  with tf.device("/cpu:0"):
    shape = tf.shape(sparse)
    batch_size = shape[0]
    sparse = tf.minimum(sparse, 1)
    indices = sparse + 2 * tf.range(0, batch_size)
    dense = tf.sparse_to_dense(indices, tf.expand_dims(2 * batch_size, 0),
                               1.0, 0.0)
    reshaped = tf.reshape(dense, [-1, 2])
  return tf.reshape(tf.slice(reshaped, [0, 0], [-1, 1]), [-1])
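
# check_for_zero uses the same flat-index trick with 2 classes: min(x, 1)
# maps every non-zero symbol to 1, the 2-way 1-hot is reshaped to
# [batch_size, 2], and column 0 is returned, i.e. 1.0 exactly where the input
# symbol is 0 (padding) and 0.0 elsewhere.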


class NeuralGPU(object):
  """Neural GPU Model."""

  def __init__(self, nmaps, vec_size, niclass, noclass, dropout, rx_step,
               max_grad_norm, cutoff, nconvs, kw, kh, height, mode,
               learning_rate, pull, pull_incr, min_length, act_noise=0.0):
    # Feeds for parameters and ops to update them.
    self.global_step = tf.Variable(0, trainable=False)
    self.cur_length = tf.Variable(min_length, trainable=False)
    self.cur_length_incr_op = self.cur_length.assign_add(1)
    self.lr = tf.Variable(float(learning_rate), trainable=False)
    self.lr_decay_op = self.lr.assign(self.lr * 0.98)
    self.pull = tf.Variable(float(pull), trainable=False)
    self.pull_incr_op = self.pull.assign(self.pull * pull_incr)
    self.do_training = tf.placeholder(tf.float32, name="do_training")
    self.noise_param = tf.placeholder(tf.float32, name="noise_param")

    # Feeds for inputs, targets, outputs, losses, etc.
    self.input = []
    self.target = []
    for l in xrange(data_utils.forward_max + 1):
      self.input.append(tf.placeholder(tf.int32, name="inp{0}".format(l)))
      self.target.append(tf.placeholder(tf.int32, name="tgt{0}".format(l)))
    self.outputs = []
    self.losses = []
    self.grad_norms = []
    self.updates = []

    # Computation.
    inp0_shape = tf.shape(self.input[0])
    batch_size = inp0_shape[0]
    with tf.device("/cpu:0"):
      emb_weights = tf.get_variable(
          "embedding", [niclass, vec_size],
          initializer=tf.random_uniform_initializer(-1.7, 1.7))
      e0 = tf.scatter_update(emb_weights,
                             tf.constant(0, dtype=tf.int32, shape=[1]),
                             tf.zeros([1, vec_size]))
    adam = tf.train.AdamOptimizer(self.lr, epsilon=1e-4)

    # Main graph creation loop, for every bin in data_utils.
    self.steps = []
    for length in sorted(list(set(data_utils.bins + [data_utils.forward_max]))):
      data_utils.print_out("Creating model for bin of length %d." % length)
      start_time = time.time()
      if length > data_utils.bins[0]:
        tf.get_variable_scope().reuse_variables()

      # Embed inputs and calculate mask.
      with tf.device("/cpu:0"):
        with tf.control_dependencies([e0]):
          embedded = [tf.nn.embedding_lookup(emb_weights, self.input[l])
                      for l in xrange(length)]
        # Mask to 0-out padding space in each step.
        imask = [check_for_zero(self.input[l]) for l in xrange(length)]
        omask = [check_for_zero(self.target[l]) for l in xrange(length)]
        mask = [1.0 - (imask[i] * omask[i]) for i in xrange(length)]
        mask = [tf.reshape(m, [-1, 1]) for m in mask]
        # Use a shifted mask for step scaling and concatenated for weights.
        shifted_mask = mask + [tf.zeros_like(mask[0])]
        scales = [shifted_mask[i] * (1.0 - shifted_mask[i + 1])
                  for i in xrange(length)]
        scales = [tf.reshape(s, [-1, 1, 1, 1]) for s in scales]
        mask = tf.concat(1, mask[0:length])  # batch x length
        weights = mask

      # Add a height dimension to mask to use later for masking.
      mask = tf.reshape(mask, [-1, length, 1, 1])
      mask = tf.concat(2, [mask for _ in xrange(height)]) + tf.zeros(
          tf.pack([batch_size, length, height, nmaps]), dtype=tf.float32)

      # Start is a length-list of batch-by-nmaps tensors, reshape and concat.
      start = [tf.tanh(embedded[l]) for l in xrange(length)]
      start = [tf.reshape(start[l], [-1, 1, nmaps]) for l in xrange(length)]
      start = tf.reshape(tf.concat(1, start), [-1, length, 1, nmaps])

      # First image comes from start by applying one convolution and adding 0s.
      first = conv_linear(start, 1, 1, vec_size, nmaps, True, 0.0, "input")
      first = [first] + [tf.zeros(tf.pack([batch_size, length, 1, nmaps]),
                                  dtype=tf.float32) for _ in xrange(height - 1)]
      first = tf.concat(2, first)

      # Computation steps.
      keep_prob = 1.0 - self.do_training * (dropout * 8.0 / float(length))
      step = [tf.nn.dropout(first, keep_prob) * mask]
      act_noise_scale = act_noise * self.do_training * self.pull
      outputs = []
      for it in xrange(length):
        with tf.variable_scope("RX%d" % (it % rx_step)) as vs:
          if it >= rx_step:
            vs.reuse_variables()
          cur = step[it]
          # Do nconvs-many CGRU steps.
          for layer in xrange(nconvs):
            cur = conv_gru([], cur, kw, kh, nmaps, cutoff, "cgru_%d" % layer)
            cur *= mask
          outputs.append(tf.slice(cur, [0, 0, 0, 0], [-1, -1, 1, -1]))
          cur = tf.nn.dropout(cur, keep_prob)
          if act_noise > 0.00001:
            cur += tf.truncated_normal(tf.shape(cur)) * act_noise_scale
          step.append(cur * mask)

      self.steps.append([tf.reshape(s, [-1, length, height * nmaps])
                         for s in step])
      # Output is the n-th step output; n = current length, as in scales.
      output = tf.add_n([outputs[i] * scales[i] for i in xrange(length)])
      # Final convolution to get logits, list outputs.
      output = conv_linear(output, 1, 1, nmaps, noclass, True, 0.0, "output")
      output = tf.reshape(output, [-1, length, noclass])
      external_output = [tf.reshape(o, [-1, noclass])
                         for o in list(tf.split(1, length, output))]
      external_output = [tf.nn.softmax(o) for o in external_output]
      self.outputs.append(external_output)

      # Calculate cross-entropy loss and normalize it.
      targets = tf.concat(1, [make_dense(self.target[l], noclass)
                              for l in xrange(length)])
      targets = tf.reshape(targets, [-1, noclass])
      xent = tf.reshape(tf.nn.softmax_cross_entropy_with_logits(
          tf.reshape(output, [-1, noclass]), targets), [-1, length])
      perp_loss = tf.reduce_sum(xent * weights)
      perp_loss /= tf.cast(batch_size, dtype=tf.float32)
      perp_loss /= length

      # Final loss: cross-entropy + shared parameter relaxation part.
      relax_dist, self.avg_op = relaxed_distance(rx_step)
      total_loss = perp_loss + relax_dist * self.pull
      self.losses.append(perp_loss)

      # Gradients and Adam update operation.
      if length == data_utils.bins[0] or (mode == 0 and
                                          length < data_utils.bins[-1] + 1):
        data_utils.print_out("Creating backward for bin of length %d." % length)
        params = tf.trainable_variables()
        grads = tf.gradients(total_loss, params)
        grads, norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.grad_norms.append(norm)
        # Add noise, scaled by the noise_param feed, to each gradient tensor
        # before applying the Adam update.
        grads = [g + tf.truncated_normal(tf.shape(g)) * self.noise_param
                 if isinstance(g, tf.Tensor) else g for g in grads]
        update = adam.apply_gradients(zip(grads, params),
                                      global_step=self.global_step)
        self.updates.append(update)
      data_utils.print_out("Created model for bin of length %d in"
                           " %.2f s." % (length, time.time() - start_time))

    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, sess, inp, target, do_backward, noise_param=None,
           get_steps=False):
    """Run a step of the network."""
    assert len(inp) == len(target)
    length = len(target)
    feed_in = {}
    feed_in[self.noise_param.name] = noise_param if noise_param else 0.0
    feed_in[self.do_training.name] = 1.0 if do_backward else 0.0
    feed_out = []
    index = len(data_utils.bins)
    if length < data_utils.bins[-1] + 1:
      index = data_utils.bins.index(length)
    if do_backward:
      feed_out.append(self.updates[index])
      feed_out.append(self.grad_norms[index])
    feed_out.append(self.losses[index])
    for l in xrange(length):
      feed_in[self.input[l].name] = inp[l]
    for l in xrange(length):
      feed_in[self.target[l].name] = target[l]
      feed_out.append(self.outputs[index][l])
    if get_steps:
      for l in xrange(length + 1):
        feed_out.append(self.steps[index][l])
    res = sess.run(feed_out, feed_in)
    offset = 0
    norm = None
    if do_backward:
      offset = 2
      norm = res[1]
    outputs = res[offset + 1:offset + 1 + length]
    steps = res[offset + 1 + length:] if get_steps else None
    return res[offset], outputs, norm, steps
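
# Rough usage sketch (illustrative only; the hyperparameter values below are
# placeholders, not verified defaults, and `inp`/`target` stand for a batch
# from data_utils whose length matches one of data_utils.bins):
#
#   model = NeuralGPU(nmaps=24, vec_size=24, niclass=33, noclass=33,
#                     dropout=0.1, rx_step=6, max_grad_norm=1.0, cutoff=1.2,
#                     nconvs=2, kw=3, kh=3, height=4, mode=0,
#                     learning_rate=0.001, pull=0.0005, pull_incr=1.2,
#                     min_length=3)
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     loss, outputs, norm, _ = model.step(sess, inp, target, do_backward=True)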