train.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
  16. r"""Script for training model.
  17. Simple command to get up and running:
  18. python train.py --memory_size=8192 \
  19. --batch_size=16 --validation_length=50 \
  20. --episode_width=5 --episode_length=30
  21. """

import logging
import os
import random

import numpy as np
import tensorflow as tf

import data_utils
import model

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('rep_dim', 128,
                        'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
tf.flags.DEFINE_integer('episode_width', 5,
                        'number of distinct labels in a single episode')
tf.flags.DEFINE_integer('memory_size', None,
                        'number of slots in memory; leave as None to '
                        'default to episode_length * batch_size')
tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
tf.flags.DEFINE_integer('validation_frequency', 20,
                        'every so many training episodes, '
                        'assess validation accuracy')
tf.flags.DEFINE_integer('validation_length', 10,
                        'number of episodes to use to compute '
                        'validation accuracy')
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
tf.flags.DEFINE_bool('use_lsh', False,
                     'use locality-sensitive hashing '
                     '(NOTE: not fully tested)')


class Trainer(object):
  """Class that takes care of training, validating, and checkpointing model."""

  def __init__(self, train_data, valid_data, input_dim, output_dim=None):
    self.train_data = train_data
    self.valid_data = valid_data
    self.input_dim = input_dim

    self.rep_dim = FLAGS.rep_dim
    self.episode_length = FLAGS.episode_length
    self.episode_width = FLAGS.episode_width
    self.batch_size = FLAGS.batch_size
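    # By default, give the memory one slot per example in a full batch of
    # episodes (episode_length * batch_size, i.e. 1600 slots with the
    # default flags).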
    self.memory_size = (self.episode_length * self.batch_size
                        if FLAGS.memory_size is None else FLAGS.memory_size)
    self.use_lsh = FLAGS.use_lsh

    self.output_dim = (output_dim if output_dim is not None
                       else self.episode_width)

  def get_model(self):
    # vocab size is the number of distinct values that
    # could go into the memory key-value storage
    vocab_size = self.episode_width * self.batch_size
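    # With the default flags this is 5 * 16 = 80: labels are remapped in
    # sample_episode_batch to be unique across the batch.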
    return model.Model(
        self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
        vocab_size, use_lsh=self.use_lsh)

  def sample_episode_batch(self, data,
                           episode_length, episode_width, batch_size):
    """Generates a random batch for training or validation.

    Structures each element of the batch as an 'episode'.
    Each episode contains episode_length examples and
    episode_width distinct labels.

    Args:
      data: A dictionary mapping label to list of examples.
      episode_length: Number of examples in each episode.
      episode_width: Distinct number of labels in each episode.
      batch_size: Batch size (number of episodes).

    Returns:
      A tuple (x, y) where x is a list of batches of examples
      with size episode_length and y is a list of batches of labels.
    """
    episodes_x = [[] for _ in range(episode_length)]
    episodes_y = [[] for _ in range(episode_length)]
    assert len(data) >= episode_width
    keys = list(data.keys())
    for b in range(batch_size):
      episode_labels = random.sample(keys, episode_width)
      remainder = episode_length % episode_width
      remainders = [0] * (episode_width - remainder) + [1] * remainder
      episode_x = [
          random.sample(data[lab],
                        r + (episode_length - remainder) // episode_width)
          for lab, r in zip(episode_labels, remainders)]
      episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
                     for i, xx in enumerate(episode_x)], [])
      random.shuffle(episode)
      # Arrange episode so that each distinct label is seen before moving to
      # its 2nd showing.
      episode.sort(key=lambda elem: elem[2])
      assert len(episode) == episode_length
      for i in range(episode_length):
        episodes_x[i].append(episode[i][0])
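        # Offset labels by b * episode_width so that label ids are unique
        # across the batch (matching vocab_size in get_model).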
        episodes_y[i].append(episode[i][1] + b * episode_width)

    return ([np.array(xx).astype('float32') for xx in episodes_x],
            [np.array(yy).astype('int32') for yy in episodes_y])

  def compute_correct(self, ys, y_preds):
    return np.mean(np.equal(y_preds, np.array(ys)))

  def individual_compute_correct(self, y, y_pred):
    return y_pred == y

  def run(self):
    """Performs training.

    Trains a model using episodic training.
    Every so often, runs some evaluations on validation data.
    """
    train_data, valid_data = self.train_data, self.valid_data
    input_dim, output_dim = self.input_dim, self.output_dim
    rep_dim, episode_length = self.rep_dim, self.episode_length
    episode_width, memory_size = self.episode_width, self.memory_size
    batch_size = self.batch_size

    train_size = len(train_data)
    valid_size = len(valid_data)
    logging.info('train_size (number of labels) %d', train_size)
    logging.info('valid_size (number of labels) %d', valid_size)
    logging.info('input_dim %d', input_dim)
    logging.info('output_dim %d', output_dim)
    logging.info('rep_dim %d', rep_dim)
    logging.info('episode_length %d', episode_length)
    logging.info('episode_width %d', episode_width)
    logging.info('memory_size %d', memory_size)
    logging.info('batch_size %d', batch_size)

    assert all(len(v) >= float(episode_length) / episode_width
               for v in train_data.values())
    assert all(len(v) >= float(episode_length) / episode_width
               for v in valid_data.values())

    output_dim = episode_width
    self.model = self.get_model()
    self.model.setup()

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    ckpt = None
    if FLAGS.save_dir:
      ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
    if ckpt and ckpt.model_checkpoint_path:
      logging.info('restoring from %s', ckpt.model_checkpoint_path)
      saver.restore(sess, ckpt.model_checkpoint_path)

    logging.info('starting now')
    losses = []
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    for i in range(FLAGS.num_episodes):
      x, y = self.sample_episode_batch(
          train_data, episode_length, episode_width, batch_size)
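      # One gradient step on the episode; clear_memory=True presumably
      # resets the model's external memory so episodes stay independent.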
      outputs = self.model.episode_step(sess, x, y, clear_memory=True)
      loss = outputs
      losses.append(loss)
      if i % FLAGS.validation_frequency == 0:
        logging.info('episode batch %d, avg train loss %f',
                     i, np.mean(losses))
        losses = []

        # validation
        correct = []
        correct_by_shot = dict((k, []) for k in range(self.episode_width + 1))
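        # correct_by_shot[k] collects outcomes of predictions made when the
        # query's label has already been seen k times in the episode
        # (k = 0 means the label has not been seen before).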
        for _ in range(FLAGS.validation_length):
          x, y = self.sample_episode_batch(
              valid_data, episode_length, episode_width, 1)
          outputs = self.model.episode_predict(
              sess, x, y, clear_memory=True)
          y_preds = outputs
          correct.append(self.compute_correct(np.array(y), y_preds))

          # compute per-shot accuracies
          seen_counts = [[0] * episode_width for _ in range(batch_size)]
          # loop over episode steps
          for yy, yy_preds in zip(y, y_preds):
            # loop over batch examples
            for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
              yyy, yyy_preds = int(yyy), int(yyy_preds)
              count = seen_counts[k][yyy % self.episode_width]
              if count in correct_by_shot:
                correct_by_shot[count].append(
                    self.individual_compute_correct(yyy, yyy_preds))
              seen_counts[k][yyy % self.episode_width] = count + 1

        logging.info('validation overall accuracy %f', np.mean(correct))
        logging.info('%d-shot: %.3f, ' * (self.episode_width + 1),
                     *sum([[k, np.mean(correct_by_shot[k])]
                           for k in range(self.episode_width + 1)], []))

        if saver and FLAGS.save_dir:
          saved_file = saver.save(sess,
                                  os.path.join(FLAGS.save_dir, 'model.ckpt'),
                                  global_step=self.model.global_step)
          logging.info('saved model to %s', saved_file)


def main(unused_argv):
  train_data, valid_data = data_utils.get_data()
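  # Inputs are flattened IMAGE_NEW_SIZE x IMAGE_NEW_SIZE images, hence the
  # squared input dimension.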
  trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
  trainer.run()


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()