# memory.py (15 KB)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Memory module for storing "nearest neighbors".

Implements a key-value memory for generalized one-shot learning
as described in the paper
"Learning to Remember Rare Events"
by Lukasz Kaiser, Ofir Nachum, Aurko Roy, Samy Bengio,
published as a conference paper at ICLR 2017.
"""
import numpy as np
import tensorflow as tf
  25. class Memory(object):
  26. """Memory module."""
  27. def __init__(self, key_dim, memory_size, vocab_size,
  28. choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
  29. var_cache_device='', nn_device=''):
  30. self.key_dim = key_dim
  31. self.memory_size = memory_size
  32. self.vocab_size = vocab_size
  33. self.choose_k = min(choose_k, memory_size)
  34. self.alpha = alpha
  35. self.correct_in_top = correct_in_top
  36. self.age_noise = age_noise
  37. self.var_cache_device = var_cache_device # Variables are cached here.
  38. self.nn_device = nn_device # Device to perform nearest neighbour matmul.
  39. caching_device = var_cache_device if var_cache_device else None
  40. self.update_memory = tf.constant(True) # Can be fed "false" if needed.
  41. self.mem_keys = tf.get_variable(
  42. 'memkeys', [self.memory_size, self.key_dim], trainable=False,
  43. initializer=tf.random_uniform_initializer(-0.0, 0.0),
  44. caching_device=caching_device)
  45. self.mem_vals = tf.get_variable(
  46. 'memvals', [self.memory_size], dtype=tf.int32, trainable=False,
  47. initializer=tf.constant_initializer(0, tf.int32),
  48. caching_device=caching_device)
  49. self.mem_age = tf.get_variable(
  50. 'memage', [self.memory_size], dtype=tf.float32, trainable=False,
  51. initializer=tf.constant_initializer(0.0), caching_device=caching_device)
  52. self.recent_idx = tf.get_variable(
  53. 'recent_idx', [self.vocab_size], dtype=tf.int32, trainable=False,
  54. initializer=tf.constant_initializer(0, tf.int32))
  55. # variable for projecting query vector into memory key
  56. self.query_proj = tf.get_variable(
  57. 'memory_query_proj', [self.key_dim, self.key_dim], dtype=tf.float32,
  58. initializer=tf.truncated_normal_initializer(0, 0.01),
  59. caching_device=caching_device)
  60. def get(self):
  61. return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
  62. def set(self, k, v, a, r=None):
  63. return tf.group(
  64. self.mem_keys.assign(k),
  65. self.mem_vals.assign(v),
  66. self.mem_age.assign(a),
  67. (self.recent_idx.assign(r) if r is not None else tf.group()))
  68. def clear(self):
  69. return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
  70. self.recent_idx])
  71. def get_hint_pool_idxs(self, normalized_query):
  72. """Get small set of idxs to compute nearest neighbor queries on.
  73. This is an expensive look-up on the whole memory that is used to
  74. avoid more expensive operations later on.
  75. Args:
  76. normalized_query: A Tensor of shape [None, key_dim].
  77. Returns:
  78. A Tensor of shape [None, choose_k] of indices in memory
  79. that are closest to the queries.
  80. """
  81. # look up in large memory, no gradients
  82. with tf.device(self.nn_device):
  83. similarities = tf.matmul(tf.stop_gradient(normalized_query),
  84. self.mem_keys, transpose_b=True, name='nn_mmul')
  85. _, hint_pool_idxs = tf.nn.top_k(
  86. tf.stop_gradient(similarities), k=self.choose_k, name='nn_topk')
  87. return hint_pool_idxs
  88. def make_update_op(self, upd_idxs, upd_keys, upd_vals,
  89. batch_size, use_recent_idx, intended_output):
  90. """Function that creates all the update ops."""
  91. mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
  92. dtype=tf.float32))
  93. with tf.control_dependencies([mem_age_incr]):
  94. mem_age_upd = tf.scatter_update(
  95. self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))
  96. mem_key_upd = tf.scatter_update(
  97. self.mem_keys, upd_idxs, upd_keys)
  98. mem_val_upd = tf.scatter_update(
  99. self.mem_vals, upd_idxs, upd_vals)
  100. if use_recent_idx:
  101. recent_idx_upd = tf.scatter_update(
  102. self.recent_idx, intended_output, upd_idxs)
  103. else:
  104. recent_idx_upd = tf.group()
  105. return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
  106. def query(self, query_vec, intended_output, use_recent_idx=True):
  107. """Queries memory for nearest neighbor.
  108. Args:
  109. query_vec: A batch of vectors to query (embedding of input to model).
  110. intended_output: The values that would be the correct output of the
  111. memory.
  112. use_recent_idx: Whether to always insert at least one instance of a
  113. correct memory fetch.
  114. Returns:
  115. A tuple (result, mask, teacher_loss).
  116. result: The result of the memory look up.
  117. mask: The affinity of the query to the result.
  118. teacher_loss: The loss for training the memory module.
  119. """
  120. batch_size = tf.shape(query_vec)[0]
  121. output_given = intended_output is not None
  122. # prepare query for memory lookup
  123. query_vec = tf.matmul(query_vec, self.query_proj)
  124. normalized_query = tf.nn.l2_normalize(query_vec, dim=1)
  125. hint_pool_idxs = self.get_hint_pool_idxs(normalized_query)
  126. if output_given and use_recent_idx: # add at least one correct memory
  127. most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
  128. hint_pool_idxs = tf.concat(
  129. axis=1,
  130. values=[hint_pool_idxs, tf.expand_dims(most_recent_hint_idx, 1)])
  131. choose_k = tf.shape(hint_pool_idxs)[1]
  132. with tf.device(self.var_cache_device):
  133. # create small memory and look up with gradients
  134. my_mem_keys = tf.stop_gradient(tf.gather(self.mem_keys, hint_pool_idxs,
  135. name='my_mem_keys_gather'))
  136. similarities = tf.matmul(tf.expand_dims(normalized_query, 1),
  137. my_mem_keys, adjoint_b=True, name='batch_mmul')
  138. hint_pool_sims = tf.squeeze(similarities, [1], name='hint_pool_sims')
  139. hint_pool_mem_vals = tf.gather(self.mem_vals, hint_pool_idxs,
  140. name='hint_pool_mem_vals')
  141. # Calculate softmax mask on the top-k if requested.
  142. # Softmax temperature. Say we have K elements at dist x and one at (x+a).
  143. # Softmax of the last is e^tm(x+a)/Ke^tm*x + e^tm(x+a) = e^tm*a/K+e^tm*a.
  144. # To make that 20% we'd need to have e^tm*a ~= 0.2K, so tm = log(0.2K)/a.
  145. softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
  146. mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
  147. # prepare hints from the teacher on hint pool
  148. teacher_hints = tf.to_float(
  149. tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
  150. teacher_hints = 1.0 - tf.minimum(1.0, teacher_hints)
  151. teacher_vals, teacher_hint_idxs = tf.nn.top_k(
  152. hint_pool_sims * teacher_hints, k=1)
  153. neg_teacher_vals, _ = tf.nn.top_k(
  154. hint_pool_sims * (1 - teacher_hints), k=1)
  155. # bring back idxs to full memory
  156. teacher_idxs = tf.gather(
  157. tf.reshape(hint_pool_idxs, [-1]),
  158. teacher_hint_idxs[:, 0] + choose_k * tf.range(batch_size))
  159. # zero-out teacher_vals if there are no hints
  160. teacher_vals *= (
  161. 1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
  162. # prepare returned values
  163. nearest_neighbor = tf.to_int32(
  164. tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
  165. no_teacher_idxs = tf.gather(
  166. tf.reshape(hint_pool_idxs, [-1]),
  167. nearest_neighbor + choose_k * tf.range(batch_size))
  168. # we'll determine whether to do an update to memory based on whether
  169. # memory was queried correctly
  170. sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
  171. incorrect_memory_lookup = tf.equal(0.0, tf.reduce_sum(sliced_hints, 1))
  172. # loss based on triplet loss
  173. teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
  174. - self.alpha)
  175. with tf.device(self.var_cache_device):
  176. result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
  177. # prepare memory updates
  178. update_keys = normalized_query
  179. update_vals = intended_output
  180. fetched_idxs = teacher_idxs # correctly fetched from memory
  181. with tf.device(self.var_cache_device):
  182. fetched_keys = tf.gather(self.mem_keys, fetched_idxs, name='fetched_keys')
  183. fetched_vals = tf.gather(self.mem_vals, fetched_idxs, name='fetched_vals')
  184. # do memory updates here
  185. fetched_keys_upd = update_keys + fetched_keys # Momentum-like update
  186. fetched_keys_upd = tf.nn.l2_normalize(fetched_keys_upd, dim=1)
  187. # Randomize age a bit, e.g., to select different ones in parallel workers.
  188. mem_age_with_noise = self.mem_age + tf.random_uniform(
  189. [self.memory_size], - self.age_noise, self.age_noise)
  190. _, oldest_idxs = tf.nn.top_k(mem_age_with_noise, k=batch_size, sorted=False)
  191. with tf.control_dependencies([result]):
  192. upd_idxs = tf.where(incorrect_memory_lookup,
  193. oldest_idxs,
  194. fetched_idxs)
  195. # upd_idxs = tf.Print(upd_idxs, [upd_idxs], "UPD IDX", summarize=8)
  196. upd_keys = tf.where(incorrect_memory_lookup,
  197. update_keys,
  198. fetched_keys_upd)
  199. upd_vals = tf.where(incorrect_memory_lookup,
  200. update_vals,
  201. fetched_vals)
  202. def make_update_op():
  203. return self.make_update_op(upd_idxs, upd_keys, upd_vals,
  204. batch_size, use_recent_idx, intended_output)
  205. update_op = tf.cond(self.update_memory, make_update_op, tf.no_op)
  206. with tf.control_dependencies([update_op]):
  207. result = tf.identity(result)
  208. mask = tf.identity(mask)
  209. teacher_loss = tf.identity(teacher_loss)
  210. return result, mask, tf.reduce_mean(teacher_loss)
  211. class LSHMemory(Memory):
  212. """Memory employing locality sensitive hashing.
  213. Note: Not fully tested.
  214. """
  215. def __init__(self, key_dim, memory_size, vocab_size,
  216. choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
  217. var_cache_device='', nn_device='',
  218. num_hashes=None, num_libraries=None):
  219. super(LSHMemory, self).__init__(
  220. key_dim, memory_size, vocab_size,
  221. choose_k=choose_k, alpha=alpha, correct_in_top=1, age_noise=age_noise,
  222. var_cache_device=var_cache_device, nn_device=nn_device)
  223. self.num_libraries = num_libraries or int(self.choose_k ** 0.5)
  224. self.num_per_hash_slot = max(1, self.choose_k // self.num_libraries)
  225. self.num_hashes = (num_hashes or
  226. int(np.log2(self.memory_size / self.num_per_hash_slot)))
  227. self.num_hashes = min(max(self.num_hashes, 1), 20)
  228. self.num_hash_slots = 2 ** self.num_hashes
  229. # hashing vectors
  230. self.hash_vecs = [
  231. tf.get_variable(
  232. 'hash_vecs%d' % i, [self.num_hashes, self.key_dim],
  233. dtype=tf.float32, trainable=False,
  234. initializer=tf.truncated_normal_initializer(0, 1))
  235. for i in xrange(self.num_libraries)]
  236. # map representing which hash slots map to which mem keys
  237. self.hash_slots = [
  238. tf.get_variable(
  239. 'hash_slots%d' % i, [self.num_hash_slots, self.num_per_hash_slot],
  240. dtype=tf.int32, trainable=False,
  241. initializer=tf.random_uniform_initializer(maxval=self.memory_size,
  242. dtype=tf.int32))
  243. for i in xrange(self.num_libraries)]
  244. def get(self): # not implemented
  245. return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
  246. def set(self, k, v, a, r=None): # not implemented
  247. return tf.group(
  248. self.mem_keys.assign(k),
  249. self.mem_vals.assign(v),
  250. self.mem_age.assign(a),
  251. (self.recent_idx.assign(r) if r is not None else tf.group()))
  252. def clear(self):
  253. return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
  254. self.recent_idx] + self.hash_slots)
  255. def get_hash_slots(self, query):
  256. """Gets hashed-to buckets for batch of queries.
  257. Args:
  258. query: 2-d Tensor of query vectors.
  259. Returns:
  260. A list of hashed-to buckets for each hash function.
  261. """
  262. binary_hash = [
  263. tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
  264. for i in xrange(self.num_libraries)]
  265. hash_slot_idxs = [
  266. tf.reduce_sum(
  267. tf.to_int32(binary_hash[i]) *
  268. tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
  269. dtype=tf.int32), 1)
  270. for i in xrange(self.num_libraries)]
  271. return hash_slot_idxs
  272. def get_hint_pool_idxs(self, normalized_query):
  273. """Get small set of idxs to compute nearest neighbor queries on.
  274. This is an expensive look-up on the whole memory that is used to
  275. avoid more expensive operations later on.
  276. Args:
  277. normalized_query: A Tensor of shape [None, key_dim].
  278. Returns:
  279. A Tensor of shape [None, choose_k] of indices in memory
  280. that are closest to the queries.
  281. """
  282. # get hash of query vecs
  283. hash_slot_idxs = self.get_hash_slots(normalized_query)
  284. # grab mem idxs in the hash slots
  285. hint_pool_idxs = [
  286. tf.maximum(tf.minimum(
  287. tf.gather(self.hash_slots[i], idxs),
  288. self.memory_size - 1), 0)
  289. for i, idxs in enumerate(hash_slot_idxs)]
  290. return tf.concat(axis=1, values=hint_pool_idxs)
  291. def make_update_op(self, upd_idxs, upd_keys, upd_vals,
  292. batch_size, use_recent_idx, intended_output):
  293. """Function that creates all the update ops."""
  294. base_update_op = super(LSHMemory, self).make_update_op(
  295. upd_idxs, upd_keys, upd_vals,
  296. batch_size, use_recent_idx, intended_output)
  297. # compute hash slots to be updated
  298. hash_slot_idxs = self.get_hash_slots(upd_keys)
  299. # make updates
  300. update_ops = []
  301. with tf.control_dependencies([base_update_op]):
  302. for i, slot_idxs in enumerate(hash_slot_idxs):
  303. # for each slot, choose which entry to replace
  304. entry_idx = tf.random_uniform([batch_size],
  305. maxval=self.num_per_hash_slot,
  306. dtype=tf.int32)
  307. entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
  308. dtype=tf.int32)
  309. entry_add = (tf.expand_dims(upd_idxs, 1) *
  310. tf.one_hot(entry_idx, self.num_per_hash_slot,
  311. dtype=tf.int32))
  312. mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
  313. with tf.control_dependencies([mul_op]):
  314. add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
  315. update_ops.append(add_op)
  316. return tf.group(*update_ops)