
add learning to remember rare events

Ofir Nachum 8 years ago
parent
commit
6a9c0da962

+ 55 - 0
learning_to_remember_rare_events/README.md

@@ -0,0 +1,55 @@
+Code for the Memory Module as described
+in "Learning to Remember Rare Events" by
+Lukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio,
+published as a conference paper at ICLR 2017.
+
+Requirements:
+* TensorFlow (see tensorflow.org for how to install)
+* Some basic command-line utilities (git, unzip).
+
+Description:
+
+The general memory module is located in memory.py.
+Some code is provided to see the memory module in
+action on the standard Omniglot dataset.
+Download and set up the dataset using data_utils.py,
+and then run the training script train.py
+(see example commands below).
+
+Note that the structure and parameters of the model
+are optimized for the data preparation as provided.
+
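+For reference, here is a minimal sketch of how the memory module in
+memory.py could be used on its own (the sizes and placeholder names below
+are illustrative only; see model.py and train.py for how the graph is
+actually built):
+
+```
+import tensorflow as tf
+
+import memory
+
+# Illustrative sizes; train.py derives the real values from its flags.
+key_dim, memory_size, vocab_size = 128, 8192, 80
+
+mem = memory.Memory(key_dim, memory_size, vocab_size)
+
+# Queries are embeddings of shape [batch_size, key_dim]; intended outputs
+# are int32 labels of shape [batch_size].
+embeddings = tf.placeholder(tf.float32, [None, key_dim])
+labels = tf.placeholder(tf.int32, [None])
+
+# result: values (labels) retrieved from memory,
+# mask: affinity of the queries to the retrieved results,
+# teacher_loss: loss for training the memory module.
+result, mask, teacher_loss = mem.query(embeddings, labels)
+```
+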
+Quick Start:
+
+First, download and set up the Omniglot data by running:
+
+```
+python data_utils.py
+```
+
+Then run the training script:
+
+```
+python train.py --memory_size=8192 \
+  --batch_size=16 --validation_length=50 \
+  --episode_width=5 --episode_length=30
+```
+
+The first validation batch may look like this (although it is noisy):
+```
+0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604,
+  4-shot: 0.656, 5-shot: 0.684
+```
+At step 500 you may see something like this:
+```
+0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940,
+  4-shot: 0.944, 5-shot: 0.916
+```
+At step 4000 you may see something like this:
+```
+0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988,
+  4-shot: 0.972, 5-shot: 0.992
+```
+
+Maintained by Ofir Nachum (ofirnachum) and
+Lukasz Kaiser (lukaszkaiser).

+ 242 - 0
learning_to_remember_rare_events/data_utils.py

@@ -0,0 +1,242 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Data loading and other utilities.
+
+Use this file to first copy over and pre-process the Omniglot dataset.
+Simply call
+  python data_utils.py
+"""
+
+import cPickle as pickle
+import logging
+import os
+import subprocess
+
+import numpy as np
+from scipy.misc import imresize
+from scipy.misc import imrotate
+from scipy.ndimage import imread
+import tensorflow as tf
+
+
+MAIN_DIR = ''
+REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
+REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
+DATA_DIR = os.path.join(REPO_DIR, 'python')
+TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
+TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
+DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')
+
+TRAIN_ROTATIONS = True  # augment training data with rotations
+TEST_ROTATIONS = False  # augment testing data with rotations
+IMAGE_ORIGINAL_SIZE = 105
+IMAGE_NEW_SIZE = 28
+
+
+def get_data():
+  """Get data in form suitable for episodic training.
+
+  Returns:
+    Train and test data as dictionaries mapping
+    label to list of examples.
+  """
+  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
+    processed_train_data = pickle.load(f)
+  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
+    processed_test_data = pickle.load(f)
+
+  train_data = {}
+  test_data = {}
+
+  for data, processed_data in zip([train_data, test_data],
+                                  [processed_train_data, processed_test_data]):
+    for image, label in zip(processed_data['images'],
+                            processed_data['labels']):
+      if label not in data:
+        data[label] = []
+      data[label].append(image.reshape([-1]).astype('float32'))
+
+  intersection = set(train_data.keys()) & set(test_data.keys())
+  assert not intersection, 'Train and test data intersect.'
+  ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
+  assert all(ok_num_examples), 'Bad number of examples in train data.'
+  ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
+  assert all(ok_num_examples), 'Bad number of examples in test data.'
+
+  logging.info('Number of labels in train data: %d.', len(train_data))
+  logging.info('Number of labels in test data: %d.', len(test_data))
+
+  return train_data, test_data
+
+
+def crawl_directory(directory, augment_with_rotations=False,
+                    first_label=0):
+  """Crawls data directory and returns stuff."""
+  label_idx = first_label
+  images = []
+  labels = []
+  info = []
+
+  # traverse root directory
+  for root, _, files in os.walk(directory):
+    logging.info('Reading files from %s', root)
+    fileflag = 0
+    for file_name in files:
+      full_file_name = os.path.join(root, file_name)
+      img = imread(full_file_name, flatten=True)
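+      # When augmenting, each rotation of a character is treated as its own
+      # class: rotations of one image get labels label_idx .. label_idx + 3.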
+      for i, angle in enumerate([0, 90, 180, 270]):
+        if not augment_with_rotations and i > 0:
+          break
+
+        images.append(imrotate(img, angle))
+        labels.append(label_idx + i)
+        info.append(full_file_name)
+
+      fileflag = 1
+
+    if fileflag:
+      label_idx += 4 if augment_with_rotations else 1
+
+  return images, labels, info
+
+
+def resize_images(images, new_width, new_height):
+  """Resize images to new dimensions."""
+  resized_images = np.zeros([images.shape[0], new_width, new_height],
+                            dtype=np.float32)
+
+  for i in range(images.shape[0]):
+    resized_images[i, :, :] = imresize(images[i, :, :],
+                                       [new_width, new_height],
+                                       interp='bilinear',
+                                       mode=None)
+  return resized_images
+
+
+def write_datafiles(directory, write_file,
+                    resize=True, rotate=False,
+                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
+                    first_label=0):
+  """Load and preprocess images from a directory and write them to a file.
+
+  Args:
+    directory: Directory of alphabet sub-directories.
+    write_file: Filename to write to.
+    resize: Whether to resize the images.
+    rotate: Whether to augment the dataset with rotations.
+    new_width: New resize width.
+    new_height: New resize height.
+    first_label: Label to start with.
+
+  Returns:
+    Number of new labels created.
+  """
+
+  # these are the default sizes for Omniglot:
+  imgwidth = IMAGE_ORIGINAL_SIZE
+  imgheight = IMAGE_ORIGINAL_SIZE
+
+  logging.info('Reading the data.')
+  images, labels, info = crawl_directory(directory,
+                                         augment_with_rotations=rotate,
+                                         first_label=first_label)
+
+  images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
+  labels_np = np.zeros([len(labels)], dtype=np.uint32)
+  for i in xrange(len(images)):
+    images_np[i, :, :] = images[i]
+    labels_np[i] = labels[i]
+
+  if resize:
+    logging.info('Resizing images.')
+    resized_images = resize_images(images_np, new_width, new_height)
+
+    logging.info('Writing resized data in float32 format.')
+    data = {'images': resized_images,
+            'labels': labels_np,
+            'info': info}
+    with tf.gfile.GFile(write_file, 'w') as f:
+      pickle.dump(data, f)
+  else:
+    logging.info('Writing original sized data in boolean format.')
+    data = {'images': images_np,
+            'labels': labels_np,
+            'info': info}
+    with tf.gfile.GFile(write_file, 'w') as f:
+      pickle.dump(data, f)
+
+  return len(np.unique(labels_np))
+
+
+def maybe_download_data():
+  """Download Omniglot repo if it does not exist."""
+  if os.path.exists(REPO_DIR):
+    logging.info('It appears that the Git repo already exists.')
+  else:
+    logging.info('It appears that the Git repo does not exist.')
+    logging.info('Cloning now.')
+
+    subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)
+
+  if os.path.exists(TRAIN_DIR):
+    logging.info('It appears that train data has already been unzipped.')
+  else:
+    logging.info('It appears that train data has not been unzipped.')
+    logging.info('Unzipping now.')
+
+    subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
+                            shell=True)
+
+  if os.path.exists(TEST_DIR):
+    logging.info('It appears that test data has already been unzipped.')
+  else:
+    logging.info('It appears that test data has not been unzipped.')
+    logging.info('Unzipping now.')
+
+    subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
+                            shell=True)
+
+
+def preprocess_omniglot():
+  """Download and prepare raw Omniglot data.
+
+  Downloads the data from GitHub if it does not exist.
+  Then loads the images, augments them with rotations if desired,
+  resizes them, and writes them to pickle files.
+  """
+
+  maybe_download_data()
+
+  directory = TRAIN_DIR
+  write_file = DATA_FILE_FORMAT % 'train'
+  num_labels = write_datafiles(
+      directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
+      new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)
+
+  directory = TEST_DIR
+  write_file = DATA_FILE_FORMAT % 'test'
+  write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
+                  new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
+                  first_label=num_labels)
+
+
+def main(unused_argv):
+  logging.basicConfig(level=logging.INFO)
+  preprocess_omniglot()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 385 - 0
learning_to_remember_rare_events/memory.py

@@ -0,0 +1,385 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Memory module for storing "nearest neighbors".
+
+Implements a key-value memory for generalized one-shot learning
+as described in the paper
+"Learning to Remember Rare Events"
+by Lukasz Kaiser, Ofir Nachum, Aurko Roy, Samy Bengio,
+published as a conference paper at ICLR 2017.
+"""
+
+import numpy as np
+import tensorflow as tf
+
+
+class Memory(object):
+  """Memory module."""
+
+  def __init__(self, key_dim, memory_size, vocab_size,
+               choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
+               var_cache_device='', nn_device=''):
+    self.key_dim = key_dim
+    self.memory_size = memory_size
+    self.vocab_size = vocab_size
+    self.choose_k = min(choose_k, memory_size)
+    self.alpha = alpha
+    self.correct_in_top = correct_in_top
+    self.age_noise = age_noise
+    self.var_cache_device = var_cache_device  # Variables are cached here.
+    self.nn_device = nn_device  # Device to perform nearest neighbor matmul.
+
+    caching_device = var_cache_device if var_cache_device else None
+    self.update_memory = tf.constant(True)  # Can be fed "false" if needed.
+    self.mem_keys = tf.get_variable(
+        'memkeys', [self.memory_size, self.key_dim], trainable=False,
+        initializer=tf.random_uniform_initializer(-0.0, 0.0),
+        caching_device=caching_device)
+    self.mem_vals = tf.get_variable(
+        'memvals', [self.memory_size], dtype=tf.int32, trainable=False,
+        initializer=tf.constant_initializer(0, tf.int32),
+        caching_device=caching_device)
+    self.mem_age = tf.get_variable(
+        'memage', [self.memory_size], dtype=tf.float32, trainable=False,
+        initializer=tf.constant_initializer(0.0), caching_device=caching_device)
+    self.recent_idx = tf.get_variable(
+        'recent_idx', [self.vocab_size], dtype=tf.int32, trainable=False,
+        initializer=tf.constant_initializer(0, tf.int32))
+
+    # variable for projecting query vector into memory key
+    self.query_proj = tf.get_variable(
+        'memory_query_proj', [self.key_dim, self.key_dim], dtype=tf.float32,
+        initializer=tf.truncated_normal_initializer(0, 0.01),
+        caching_device=caching_device)
+
+  def get(self):
+    return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
+
+  def set(self, k, v, a, r=None):
+    return tf.group(
+        self.mem_keys.assign(k),
+        self.mem_vals.assign(v),
+        self.mem_age.assign(a),
+        (self.recent_idx.assign(r) if r is not None else tf.group()))
+
+  def clear(self):
+    return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
+                                     self.recent_idx])
+
+  def get_hint_pool_idxs(self, normalized_query):
+    """Get small set of idxs to compute nearest neighbor queries on.
+
+    This is an expensive look-up on the whole memory that is used to
+    avoid more expensive operations later on.
+
+    Args:
+      normalized_query: A Tensor of shape [None, key_dim].
+
+    Returns:
+      A Tensor of shape [None, choose_k] of indices in memory
+      that are closest to the queries.
+
+    """
+    # look up in large memory, no gradients
+    with tf.device(self.nn_device):
+      similarities = tf.matmul(tf.stop_gradient(normalized_query),
+                               self.mem_keys, transpose_b=True, name='nn_mmul')
+    _, hint_pool_idxs = tf.nn.top_k(
+        tf.stop_gradient(similarities), k=self.choose_k, name='nn_topk')
+    return hint_pool_idxs
+
+  def make_update_op(self, upd_idxs, upd_keys, upd_vals,
+                     batch_size, use_recent_idx, intended_output):
+    """Function that creates all the update ops."""
+    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
+                                                   dtype=tf.float32))
+    with tf.control_dependencies([mem_age_incr]):
+      mem_age_upd = tf.scatter_update(
+          self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))
+
+    mem_key_upd = tf.scatter_update(
+        self.mem_keys, upd_idxs, upd_keys)
+    mem_val_upd = tf.scatter_update(
+        self.mem_vals, upd_idxs, upd_vals)
+
+    if use_recent_idx:
+      recent_idx_upd = tf.scatter_update(
+          self.recent_idx, intended_output, upd_idxs)
+    else:
+      recent_idx_upd = tf.group()
+
+    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
+
+  def query(self, query_vec, intended_output, use_recent_idx=True):
+    """Queries memory for nearest neighbor.
+
+    Args:
+      query_vec: A batch of vectors to query (embedding of input to model).
+      intended_output: The values that would be the correct output of the
+        memory.
+      use_recent_idx: Whether to always insert at least one instance of a
+        correct memory fetch.
+
+    Returns:
+      A tuple (result, mask, teacher_loss).
+      result: The result of the memory look up.
+      mask: The affinity of the query to the result.
+      teacher_loss: The loss for training the memory module.
+    """
+
+    batch_size = tf.shape(query_vec)[0]
+    output_given = intended_output is not None
+
+    # prepare query for memory lookup
+    query_vec = tf.matmul(query_vec, self.query_proj)
+    normalized_query = tf.nn.l2_normalize(query_vec, dim=1)
+
+    hint_pool_idxs = self.get_hint_pool_idxs(normalized_query)
+
+    if output_given and use_recent_idx:  # add at least one correct memory
+      most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
+      hint_pool_idxs = tf.concat([hint_pool_idxs,
+                                  tf.expand_dims(most_recent_hint_idx, 1)], 1)
+    choose_k = tf.shape(hint_pool_idxs)[1]
+
+    with tf.device(self.var_cache_device):
+      # create small memory and look up with gradients
+      my_mem_keys = tf.stop_gradient(tf.gather(self.mem_keys, hint_pool_idxs,
+                                               name='my_mem_keys_gather'))
+      similarities = tf.matmul(tf.expand_dims(normalized_query, 1),
+                               my_mem_keys, adjoint_b=True, name='batch_mmul')
+      hint_pool_sims = tf.squeeze(similarities, [1], name='hint_pool_sims')
+      hint_pool_mem_vals = tf.gather(self.mem_vals, hint_pool_idxs,
+                                     name='hint_pool_mem_vals')
+    # Calculate softmax mask on the top-k if requested.
+    # Softmax temperature. Say we have K elements at distance x and one at
+    # (x+a). The softmax weight of that last element is
+    #   e^{tm*(x+a)} / (K*e^{tm*x} + e^{tm*(x+a)}) = e^{tm*a} / (K + e^{tm*a}).
+    # To make that 20% we'd need e^{tm*a} ~= 0.2K, so tm = log(0.2K)/a.
+    softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
+    mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
+
+    # prepare hints from the teacher on hint pool
+    teacher_hints = tf.to_float(
+        tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
+    teacher_hints = 1.0 - tf.minimum(1.0, teacher_hints)
+
+    teacher_vals, teacher_hint_idxs = tf.nn.top_k(
+        hint_pool_sims * teacher_hints, k=1)
+    neg_teacher_vals, _ = tf.nn.top_k(
+        hint_pool_sims * (1 - teacher_hints), k=1)
+
+    # bring back idxs to full memory
+    teacher_idxs = tf.gather(
+        tf.reshape(hint_pool_idxs, [-1]),
+        teacher_hint_idxs[:, 0] + choose_k * tf.range(batch_size))
+
+    # zero-out teacher_vals if there are no hints
+    teacher_vals *= (
+        1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
+
+    # prepare returned values
+    nearest_neighbor = tf.to_int32(
+        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
+    no_teacher_idxs = tf.gather(
+        tf.reshape(hint_pool_idxs, [-1]),
+        nearest_neighbor + choose_k * tf.range(batch_size))
+
+    # we'll determine whether to do an update to memory based on whether
+    # memory was queried correctly
+    sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
+    incorrect_memory_lookup = tf.equal(0.0, tf.reduce_sum(sliced_hints, 1))
+
+    # loss based on triplet loss
+    teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
+                    - self.alpha)
+
+    with tf.device(self.var_cache_device):
+      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
+
+    # prepare memory updates
+    update_keys = normalized_query
+    update_vals = intended_output
+
+    fetched_idxs = teacher_idxs  # correctly fetched from memory
+    with tf.device(self.var_cache_device):
+      fetched_keys = tf.gather(self.mem_keys, fetched_idxs, name='fetched_keys')
+      fetched_vals = tf.gather(self.mem_vals, fetched_idxs, name='fetched_vals')
+
+    # do memory updates here
+    fetched_keys_upd = update_keys + fetched_keys  # Momentum-like update
+    fetched_keys_upd = tf.nn.l2_normalize(fetched_keys_upd, dim=1)
+    # Randomize age a bit, e.g., to select different ones in parallel workers.
+    mem_age_with_noise = self.mem_age + tf.random_uniform(
+        [self.memory_size], - self.age_noise, self.age_noise)
+
+    _, oldest_idxs = tf.nn.top_k(mem_age_with_noise, k=batch_size, sorted=False)
+
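+    # If the memory lookup was incorrect, write the new (key, value) pair into
+    # the oldest slots; otherwise refresh the slot that returned the correct
+    # value with the averaged, re-normalized key.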
+    with tf.control_dependencies([result]):
+      upd_idxs = tf.where(incorrect_memory_lookup,
+                          oldest_idxs,
+                          fetched_idxs)
+      # upd_idxs = tf.Print(upd_idxs, [upd_idxs], "UPD IDX", summarize=8)
+      upd_keys = tf.where(incorrect_memory_lookup,
+                          update_keys,
+                          fetched_keys_upd)
+      upd_vals = tf.where(incorrect_memory_lookup,
+                          update_vals,
+                          fetched_vals)
+
+    def make_update_op():
+      return self.make_update_op(upd_idxs, upd_keys, upd_vals,
+                                 batch_size, use_recent_idx, intended_output)
+
+    update_op = tf.cond(self.update_memory, make_update_op, tf.no_op)
+
+    with tf.control_dependencies([update_op]):
+      result = tf.identity(result)
+      mask = tf.identity(mask)
+      teacher_loss = tf.identity(teacher_loss)
+
+    return result, mask, tf.reduce_mean(teacher_loss)
+
+
+class LSHMemory(Memory):
+  """Memory employing locality sensitive hashing.
+
+  Note: Not fully tested.
+  """
+
+  def __init__(self, key_dim, memory_size, vocab_size,
+               choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
+               var_cache_device='', nn_device='',
+               num_hashes=None, num_libraries=None):
+    super(LSHMemory, self).__init__(
+        key_dim, memory_size, vocab_size,
+        choose_k=choose_k, alpha=alpha, correct_in_top=1, age_noise=age_noise,
+        var_cache_device=var_cache_device, nn_device=nn_device)
+
+    self.num_libraries = num_libraries or int(self.choose_k ** 0.5)
+    self.num_per_hash_slot = max(1, self.choose_k // self.num_libraries)
+    self.num_hashes = (num_hashes or
+                       int(np.log2(self.memory_size / self.num_per_hash_slot)))
+    self.num_hashes = min(max(self.num_hashes, 1), 20)
+    self.num_hash_slots = 2 ** self.num_hashes
+
+    # hashing vectors
+    self.hash_vecs = [
+        tf.get_variable(
+            'hash_vecs%d' % i, [self.num_hashes, self.key_dim],
+            dtype=tf.float32, trainable=False,
+            initializer=tf.truncated_normal_initializer(0, 1))
+        for i in xrange(self.num_libraries)]
+
+    # map representing which hash slots map to which mem keys
+    self.hash_slots = [
+        tf.get_variable(
+            'hash_slots%d' % i, [self.num_hash_slots, self.num_per_hash_slot],
+            dtype=tf.int32, trainable=False,
+            initializer=tf.random_uniform_initializer(maxval=self.memory_size,
+                                                      dtype=tf.int32))
+        for i in xrange(self.num_libraries)]
+
+  def get(self):  # not implemented
+    return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
+
+  def set(self, k, v, a, r=None):  # not implemented
+    return tf.group(
+        self.mem_keys.assign(k),
+        self.mem_vals.assign(v),
+        self.mem_age.assign(a),
+        (self.recent_idx.assign(r) if r is not None else tf.group()))
+
+  def clear(self):
+    return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
+                                     self.recent_idx] + self.hash_slots)
+
+  def get_hash_slots(self, query):
+    """Gets hashed-to buckets for batch of queries.
+
+    Args:
+      query: 2-d Tensor of query vectors.
+
+    Returns:
+      A list of hashed-to buckets for each hash function.
+    """
+
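+    # The sign of each random projection yields one bit of a binary code per
+    # query; the bits are packed into an integer index into the hash table
+    # (one code per hash library).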
+    binary_hash = [
+        tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
+        for i in xrange(self.num_libraries)]
+    hash_slot_idxs = [
+        tf.reduce_sum(
+            tf.to_int32(binary_hash[i]) *
+            tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
+                        dtype=tf.int32), 1)
+        for i in xrange(self.num_libraries)]
+    return hash_slot_idxs
+
+  def get_hint_pool_idxs(self, normalized_query):
+    """Get small set of idxs to compute nearest neighbor queries on.
+
+    This is an expensive look-up on the whole memory that is used to
+    avoid more expensive operations later on.
+
+    Args:
+      normalized_query: A Tensor of shape [None, key_dim].
+
+    Returns:
+      A Tensor of shape [None, choose_k] of indices in memory
+      that are closest to the queries.
+
+    """
+    # get hash of query vecs
+    hash_slot_idxs = self.get_hash_slots(normalized_query)
+
+    # grab mem idxs in the hash slots
+    hint_pool_idxs = [
+        tf.maximum(tf.minimum(
+            tf.gather(self.hash_slots[i], idxs),
+            self.memory_size - 1), 0)
+        for i, idxs in enumerate(hash_slot_idxs)]
+
+    return tf.concat(hint_pool_idxs, 1)
+
+  def make_update_op(self, upd_idxs, upd_keys, upd_vals,
+                     batch_size, use_recent_idx, intended_output):
+    """Function that creates all the update ops."""
+    base_update_op = super(LSHMemory, self).make_update_op(
+        upd_idxs, upd_keys, upd_vals,
+        batch_size, use_recent_idx, intended_output)
+
+    # compute hash slots to be updated
+    hash_slot_idxs = self.get_hash_slots(upd_keys)
+
+    # make updates
+    update_ops = []
+    with tf.control_dependencies([base_update_op]):
+      for i, slot_idxs in enumerate(hash_slot_idxs):
+        # for each slot, choose which entry to replace
+        entry_idx = tf.random_uniform([batch_size],
+                                      maxval=self.num_per_hash_slot,
+                                      dtype=tf.int32)
+        entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
+                                   dtype=tf.int32)
+        entry_add = (tf.expand_dims(upd_idxs, 1) *
+                     tf.one_hot(entry_idx, self.num_per_hash_slot,
+                                dtype=tf.int32))
+
+        mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
+        with tf.control_dependencies([mul_op]):
+          add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
+          update_ops.append(add_op)
+
+    return tf.group(*update_ops)

+ 308 - 0
learning_to_remember_rare_events/model.py

@@ -0,0 +1,308 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+"""Model using memory component.
+
+The model embeds images using a standard CNN architecture.
+These embeddings are used as keys to the memory component,
+which returns nearest neighbors.
+"""
+
+import tensorflow as tf
+
+import memory
+
+FLAGS = tf.flags.FLAGS
+
+
+class BasicClassifier(object):
+
+  def __init__(self, output_dim):
+    self.output_dim = output_dim
+
+  def core_builder(self, memory_val, x, y):
+    del x, y
+    y_pred = memory_val
+    loss = 0.0
+
+    return loss, y_pred
+
+
+class LeNet(object):
+  """Standard CNN architecture."""
+
+  def __init__(self, image_size, num_channels, hidden_dim):
+    self.image_size = image_size
+    self.num_channels = num_channels
+    self.hidden_dim = hidden_dim
+    self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
+    self.vector_init = tf.constant_initializer(0.0)
+
+  def core_builder(self, x):
+    """Embeds x using standard CNN architecture.
+
+    Args:
+      x: Batch of images as a 2-d Tensor [batch_size, -1].
+
+    Returns:
+      A 2-d Tensor [batch_size, hidden_dim] of embedded images.
+    """
+
+    ch1 = 32 * 2  # number of channels in 1st layer
+    ch2 = 64 * 2  # number of channels in 2nd layer
+    conv1_weights = tf.get_variable('conv1_w',
+                                    [3, 3, self.num_channels, ch1],
+                                    initializer=self.matrix_init)
+    conv1_biases = tf.get_variable('conv1_b', [ch1],
+                                   initializer=self.vector_init)
+    conv1a_weights = tf.get_variable('conv1a_w',
+                                     [3, 3, ch1, ch1],
+                                     initializer=self.matrix_init)
+    conv1a_biases = tf.get_variable('conv1a_b', [ch1],
+                                    initializer=self.vector_init)
+
+    conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
+                                    initializer=self.matrix_init)
+    conv2_biases = tf.get_variable('conv2_b', [ch2],
+                                   initializer=self.vector_init)
+    conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
+                                     initializer=self.matrix_init)
+    conv2a_biases = tf.get_variable('conv2a_b', [ch2],
+                                    initializer=self.vector_init)
+
+    # fully connected
+    fc1_weights = tf.get_variable(
+        'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
+                  self.hidden_dim], initializer=self.matrix_init)
+    fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
+                                 initializer=self.vector_init)
+
+    # define model
+    x = tf.reshape(x,
+                   [-1, self.image_size, self.image_size, self.num_channels])
+    batch_size = tf.shape(x)[0]
+
+    conv1 = tf.nn.conv2d(x, conv1_weights,
+                         strides=[1, 1, 1, 1], padding='SAME')
+    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
+    conv1 = tf.nn.conv2d(relu1, conv1a_weights,
+                         strides=[1, 1, 1, 1], padding='SAME')
+    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))
+
+    pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
+                           strides=[1, 2, 2, 1], padding='SAME')
+
+    conv2 = tf.nn.conv2d(pool1, conv2_weights,
+                         strides=[1, 1, 1, 1], padding='SAME')
+    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
+    conv2 = tf.nn.conv2d(relu2, conv2a_weights,
+                         strides=[1, 1, 1, 1], padding='SAME')
+    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))
+
+    pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
+                           strides=[1, 2, 2, 1], padding='SAME')
+
+    reshape = tf.reshape(pool2, [batch_size, -1])
+    hidden = tf.matmul(reshape, fc1_weights) + fc1_biases
+
+    return hidden
+
+
+class Model(object):
+  """Model for coordinating between CNN embedder and Memory module."""
+
+  def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
+               learning_rate=0.0001, use_lsh=False):
+    self.input_dim = input_dim
+    self.output_dim = output_dim
+    self.rep_dim = rep_dim
+    self.memory_size = memory_size
+    self.vocab_size = vocab_size
+    self.learning_rate = learning_rate
+    self.use_lsh = use_lsh
+
+    self.embedder = self.get_embedder()
+    self.memory = self.get_memory()
+    self.classifier = self.get_classifier()
+
+    self.global_step = tf.contrib.framework.get_or_create_global_step()
+
+  def get_embedder(self):
+    return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)
+
+  def get_memory(self):
+    cls = memory.LSHMemory if self.use_lsh else memory.Memory
+    return cls(self.rep_dim, self.memory_size, self.vocab_size)
+
+  def get_classifier(self):
+    return BasicClassifier(self.output_dim)
+
+  def core_builder(self, x, y, keep_prob, use_recent_idx=True):
+    embeddings = self.embedder.core_builder(x)
+    if keep_prob < 1.0:
+      embeddings = tf.nn.dropout(embeddings, keep_prob)
+    memory_val, _, teacher_loss = self.memory.query(
+        embeddings, y, use_recent_idx=use_recent_idx)
+    loss, y_pred = self.classifier.core_builder(memory_val, x, y)
+
+    return loss + teacher_loss, y_pred
+
+  def train(self, x, y):
+    loss, _ = self.core_builder(x, y, keep_prob=0.3)
+    gradient_ops = self.training_ops(loss)
+    return loss, gradient_ops
+
+  def eval(self, x, y):
+    _, y_preds = self.core_builder(x, y, keep_prob=1.0,
+                                   use_recent_idx=False)
+    return y_preds
+
+  def get_xy_placeholders(self):
+    return (tf.placeholder(tf.float32, [None, self.input_dim]),
+            tf.placeholder(tf.int32, [None]))
+
+  def setup(self):
+    """Sets up all components of the computation graph."""
+
+    self.x, self.y = self.get_xy_placeholders()
+
+    with tf.variable_scope('core', reuse=None):
+      self.loss, self.gradient_ops = self.train(self.x, self.y)
+    with tf.variable_scope('core', reuse=True):
+      self.y_preds = self.eval(self.x, self.y)
+
+    # setup memory "reset" ops
+    (self.mem_keys, self.mem_vals,
+     self.mem_age, self.recent_idx) = self.memory.get()
+    self.mem_keys_reset = tf.placeholder(self.mem_keys.dtype,
+                                         tf.identity(self.mem_keys).shape)
+    self.mem_vals_reset = tf.placeholder(self.mem_vals.dtype,
+                                         tf.identity(self.mem_vals).shape)
+    self.mem_age_reset = tf.placeholder(self.mem_age.dtype,
+                                        tf.identity(self.mem_age).shape)
+    self.recent_idx_reset = tf.placeholder(self.recent_idx.dtype,
+                                           tf.identity(self.recent_idx).shape)
+    self.mem_reset_op = self.memory.set(self.mem_keys_reset,
+                                        self.mem_vals_reset,
+                                        self.mem_age_reset,
+                                        None)
+
+  def training_ops(self, loss):
+    opt = self.get_optimizer()
+    params = tf.trainable_variables()
+    gradients = tf.gradients(loss, params)
+    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
+    return opt.apply_gradients(zip(clipped_gradients, params),
+                               global_step=self.global_step)
+
+  def get_optimizer(self):
+    return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
+                                  epsilon=1e-4)
+
+  def one_step(self, sess, x, y):
+    outputs = [self.loss, self.gradient_ops]
+    return sess.run(outputs, feed_dict={self.x: x, self.y: y})
+
+  def episode_step(self, sess, x, y, clear_memory=False):
+    """Performs training steps on episodic input.
+
+    Args:
+      sess: A TensorFlow Session.
+      x: A list of batches of images defining the episode.
+      y: A list of batches of labels corresponding to x.
+      clear_memory: Whether to clear the memory before the episode.
+
+    Returns:
+      List of losses the same length as the episode.
+    """
+
+    outputs = [self.loss, self.gradient_ops]
+
+    if clear_memory:
+      self.clear_memory(sess)
+
+    losses = []
+    for xx, yy in zip(x, y):
+      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
+      loss = out[0]
+      losses.append(loss)
+
+    return losses
+
+  def predict(self, sess, x, y=None):
+    """Predict the labels on a single batch of examples.
+
+    Args:
+      sess: A TensorFlow Session.
+      x: A batch of images.
+      y: The labels for the images in x.
+        This allows for updating the memory.
+
+    Returns:
+      Predicted y.
+    """
+
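+    # Snapshot the memory contents so that this prediction (which also writes
+    # to memory) can be rolled back below, leaving the training-time memory
+    # state untouched.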
+    cur_memory = sess.run([self.mem_keys, self.mem_vals,
+                           self.mem_age])
+
+    outputs = [self.y_preds]
+    if y is None:
+      ret = sess.run(outputs, feed_dict={self.x: x})
+    else:
+      ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})
+
+    sess.run([self.mem_reset_op],
+             feed_dict={self.mem_keys_reset: cur_memory[0],
+                        self.mem_vals_reset: cur_memory[1],
+                        self.mem_age_reset: cur_memory[2]})
+
+    return ret
+
+  def episode_predict(self, sess, x, y, clear_memory=False):
+    """Predict the labels on an episode of examples.
+
+    Args:
+      sess: A TensorFlow Session.
+      x: A list of batches of images.
+      y: A list of labels for the images in x.
+        This allows for updating the memory.
+      clear_memory: Whether to clear the memory before the episode.
+
+    Returns:
+      List of predicted y.
+    """
+
+    cur_memory = sess.run([self.mem_keys, self.mem_vals,
+                           self.mem_age])
+
+    if clear_memory:
+      self.clear_memory(sess)
+
+    outputs = [self.y_preds]
+    y_preds = []
+    for xx, yy in zip(x, y):
+      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
+      y_pred = out[0]
+      y_preds.append(y_pred)
+
+    sess.run([self.mem_reset_op],
+             feed_dict={self.mem_keys_reset: cur_memory[0],
+                        self.mem_vals_reset: cur_memory[1],
+                        self.mem_age_reset: cur_memory[2]})
+
+    return y_preds
+
+  def clear_memory(self, sess):
+    sess.run([self.memory.clear()])

+ 241 - 0
learning_to_remember_rare_events/train.py

@@ -0,0 +1,241 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+r"""Script for training model.
+
+Simple command to get up and running:
+  python train.py --memory_size=8192 \
+      --batch_size=16 --validation_length=50 \
+      --episode_width=5 --episode_length=30
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import tensorflow as tf
+
+import data_utils
+import model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_integer('rep_dim', 128,
+                        'dimension of keys to use in memory')
+tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
+tf.flags.DEFINE_integer('episode_width', 5,
+                        'number of distinct labels in a single episode')
+tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
+                        'Leave as None to default to episode_length * batch_size')
+tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
+tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
+tf.flags.DEFINE_integer('validation_frequency', 20,
+                        'every so many training episodes, '
+                        'assess validation accuracy')
+tf.flags.DEFINE_integer('validation_length', 10,
+                        'number of episodes to use to compute '
+                        'validation accuracy')
+tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
+tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
+tf.flags.DEFINE_bool('use_lsh', False,
+                     'use locality-sensitive hashing '
+                     '(NOTE: not fully tested)')
+
+
+class Trainer(object):
+  """Class that takes care of training, validating, and checkpointing model."""
+
+  def __init__(self, train_data, valid_data, input_dim, output_dim=None):
+    self.train_data = train_data
+    self.valid_data = valid_data
+    self.input_dim = input_dim
+
+    self.rep_dim = FLAGS.rep_dim
+    self.episode_length = FLAGS.episode_length
+    self.episode_width = FLAGS.episode_width
+    self.batch_size = FLAGS.batch_size
+    self.memory_size = (self.episode_length * self.batch_size
+                        if FLAGS.memory_size is None else FLAGS.memory_size)
+    self.use_lsh = FLAGS.use_lsh
+
+    self.output_dim = (output_dim if output_dim is not None
+                       else self.episode_width)
+
+  def get_model(self):
+    # vocab size is the number of distinct values that
+    # could go into the memory key-value storage
+    vocab_size = self.episode_width * self.batch_size
+    return model.Model(
+        self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
+        vocab_size, use_lsh=self.use_lsh)
+
+  def sample_episode_batch(self, data,
+                           episode_length, episode_width, batch_size):
+    """Generates a random batch for training or validation.
+
+    Structures each element of the batch as an 'episode'.
+    Each episode contains episode_length examples and
+    episode_width distinct labels.
+
+    Args:
+      data: A dictionary mapping label to list of examples.
+      episode_length: Number of examples in each episode.
+      episode_width: Number of distinct labels in each episode.
+      batch_size: Batch size (number of episodes).
+
+    Returns:
+      A tuple (x, y) where x is a list of batches of examples
+      with size episode_length and y is a list of batches of labels.
+    """
+
+    episodes_x = [[] for _ in xrange(episode_length)]
+    episodes_y = [[] for _ in xrange(episode_length)]
+    assert len(data) >= episode_width
+    keys = data.keys()
+    for b in xrange(batch_size):
+      episode_labels = random.sample(keys, episode_width)
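+      # Split episode_length examples as evenly as possible across the
+      # episode_width labels; 'remainders' marks the labels that receive one
+      # extra example when the split is not exact.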
+      remainder = episode_length % episode_width
+      remainders = [0] * (episode_width - remainder) + [1] * remainder
+      episode_x = [
+          random.sample(data[lab],
+                        r + (episode_length - remainder) / episode_width)
+          for lab, r in zip(episode_labels, remainders)]
+      episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
+                     for i, xx in enumerate(episode_x)], [])
+      random.shuffle(episode)
+      # Arrange the episode so that each distinct label is seen once before
+      # any label is shown a second time.
+      episode.sort(key=lambda elem: elem[2])
+      assert len(episode) == episode_length
+      for i in xrange(episode_length):
+        episodes_x[i].append(episode[i][0])
+        episodes_y[i].append(episode[i][1] + b * episode_width)
+
+    return ([np.array(xx).astype('float32') for xx in episodes_x],
+            [np.array(yy).astype('int32') for yy in episodes_y])
+
+  def compute_correct(self, ys, y_preds):
+    return np.mean(np.equal(y_preds, np.array(ys)))
+
+  def individual_compute_correct(self, y, y_pred):
+    return y_pred == y
+
+  def run(self):
+    """Performs training.
+
+    Trains a model using episodic training.
+    Every so often, runs some evaluations on validation data.
+    """
+
+    train_data, valid_data = self.train_data, self.valid_data
+    input_dim, output_dim = self.input_dim, self.output_dim
+    rep_dim, episode_length = self.rep_dim, self.episode_length
+    episode_width, memory_size = self.episode_width, self.memory_size
+    batch_size = self.batch_size
+
+    train_size = len(train_data)
+    valid_size = len(valid_data)
+    logging.info('train_size (number of labels) %d', train_size)
+    logging.info('valid_size (number of labels) %d', valid_size)
+    logging.info('input_dim %d', input_dim)
+    logging.info('output_dim %d', output_dim)
+    logging.info('rep_dim %d', rep_dim)
+    logging.info('episode_length %d', episode_length)
+    logging.info('episode_width %d', episode_width)
+    logging.info('memory_size %d', memory_size)
+    logging.info('batch_size %d', batch_size)
+
+    assert all(len(v) >= float(episode_length) / episode_width
+               for v in train_data.itervalues())
+    assert all(len(v) >= float(episode_length) / episode_width
+               for v in valid_data.itervalues())
+
+    output_dim = episode_width
+    self.model = self.get_model()
+    self.model.setup()
+
+    sess = tf.Session()
+    sess.run(tf.initialize_all_variables())
+
+    saver = tf.train.Saver(max_to_keep=10)
+    ckpt = None
+    if FLAGS.save_dir:
+      ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
+    if ckpt and ckpt.model_checkpoint_path:
+      logging.info('restoring from %s', ckpt.model_checkpoint_path)
+      saver.restore(sess, ckpt.model_checkpoint_path)
+
+    logging.info('starting now')
+    losses = []
+    random.seed(FLAGS.seed)
+    np.random.seed(FLAGS.seed)
+    for i in xrange(FLAGS.num_episodes):
+      x, y = self.sample_episode_batch(
+          train_data, episode_length, episode_width, batch_size)
+      outputs = self.model.episode_step(sess, x, y, clear_memory=True)
+      loss = outputs
+      losses.append(loss)
+
+      if i % FLAGS.validation_frequency == 0:
+        logging.info('episode batch %d, avg train loss %f',
+                     i, np.mean(losses))
+        losses = []
+
+        # validation
+        correct = []
+        correct_by_shot = dict((k, []) for k in xrange(self.episode_width + 1))
+        for _ in xrange(FLAGS.validation_length):
+          x, y = self.sample_episode_batch(
+              valid_data, episode_length, episode_width, 1)
+          outputs = self.model.episode_predict(
+              sess, x, y, clear_memory=True)
+          y_preds = outputs
+          correct.append(self.compute_correct(np.array(y), y_preds))
+
+          # compute per-shot accuracies
+          seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
+          # loop over episode steps
+          for yy, yy_preds in zip(y, y_preds):
+            # loop over batch examples
+            for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
+              yyy, yyy_preds = int(yyy), int(yyy_preds)
+              count = seen_counts[k][yyy % self.episode_width]
+              if count in correct_by_shot:
+                correct_by_shot[count].append(
+                    self.individual_compute_correct(yyy, yyy_preds))
+              seen_counts[k][yyy % self.episode_width] = count + 1
+
+        logging.info('validation overall accuracy %f', np.mean(correct))
+        logging.info('%d-shot: %.3f, ' * (self.episode_width + 1),
+                     *sum([[k, np.mean(correct_by_shot[k])]
+                           for k in xrange(self.episode_width + 1)], []))
+
+        if saver and FLAGS.save_dir:
+          saved_file = saver.save(sess,
+                                  os.path.join(FLAGS.save_dir, 'model.ckpt'),
+                                  global_step=self.model.global_step)
+          logging.info('saved model to %s', saved_file)
+
+
+def main(unused_argv):
+  train_data, valid_data = data_utils.get_data()
+  trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
+  trainer.run()
+
+
+if __name__ == '__main__':
+  logging.basicConfig(level=logging.INFO)
+  tf.app.run()