names.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """A library showing off sequence recognition and generation with the simple
  15. example of names.
  16. We use recurrent neural nets to learn complex functions able to recognize and
  17. generate sequences of a given form. This can be used for natural language
  18. syntax recognition, dynamically generating maps or puzzles and of course
  19. baby name generation.
  20. Before using this module, it is recommended to read the TensorFlow tutorial on
  21. recurrent neural nets, as it explains the basic concepts of this model and
  22. presents the PTB module on which this model is based.
  23. Here is an overview of the functions available in this module:
  24. * RNN Module for sequence functions based on PTB
  25. * Name recognition specifically for recognizing names, but can be adapted to
  26. recognizing sequence patterns
  27. * Name generation specifically for generating names, but can be adapted to
  28. generating arbitrary sequence patterns
  29. """
  30. from __future__ import absolute_import
  31. from __future__ import division
  32. from __future__ import print_function
  33. import time
  34. import tensorflow as tf
  35. import numpy as np
  36. from model import NamignizerModel
  37. import data_utils
  38. class SmallConfig(object):
  39. """Small config."""
  40. init_scale = 0.1
  41. learning_rate = 1.0
  42. max_grad_norm = 5
  43. num_layers = 2
  44. num_steps = 20
  45. hidden_size = 200
  46. max_epoch = 4
  47. max_max_epoch = 13
  48. keep_prob = 1.0
  49. lr_decay = 0.5
  50. batch_size = 20
  51. vocab_size = 27
  52. epoch_size = 100
  53. class LargeConfig(object):
  54. """Medium config."""
  55. init_scale = 0.05
  56. learning_rate = 1.0
  57. max_grad_norm = 5
  58. num_layers = 2
  59. num_steps = 35
  60. hidden_size = 650
  61. max_epoch = 6
  62. max_max_epoch = 39
  63. keep_prob = 0.5
  64. lr_decay = 0.8
  65. batch_size = 20
  66. vocab_size = 27
  67. epoch_size = 100
  68. class TestConfig(object):
  69. """Tiny config, for testing."""
  70. init_scale = 0.1
  71. learning_rate = 1.0
  72. max_grad_norm = 1
  73. num_layers = 1
  74. num_steps = 2
  75. hidden_size = 2
  76. max_epoch = 1
  77. max_max_epoch = 1
  78. keep_prob = 1.0
  79. lr_decay = 0.5
  80. batch_size = 20
  81. vocab_size = 27
  82. epoch_size = 100
  83. def run_epoch(session, m, names, counts, epoch_size, eval_op, verbose=False):
  84. """Runs the model on the given data for one epoch
  85. Args:
  86. session: the tf session holding the model graph
  87. m: an instance of the NamignizerModel
  88. names: a set of lowercase names of 26 characters
  89. counts: a list of the frequency of the above names
  90. epoch_size: the number of batches to run
  91. eval_op: whether to change the params or not, and how to do it
  92. Kwargs:
  93. verbose: whether to print out state of training during the epoch
  94. Returns:
  95. cost: the average cost during the last stage of the epoch
  96. """
  97. start_time = time.time()
  98. costs = 0.0
  99. iters = 0
  100. for step, (x, y) in enumerate(data_utils.namignizer_iterator(names, counts,
  101. m.batch_size, m.num_steps, epoch_size)):
  102. cost, _ = session.run([m.cost, eval_op],
  103. {m.input_data: x,
  104. m.targets: y,
  105. m.weights: np.ones(m.batch_size * m.num_steps)})
  106. costs += cost
  107. iters += m.num_steps
  108. if verbose and step % (epoch_size // 10) == 9:
  109. print("%.3f perplexity: %.3f speed: %.0f lps" %
  110. (step * 1.0 / epoch_size, np.exp(costs / iters),
  111. iters * m.batch_size / (time.time() - start_time)))
  112. if step >= epoch_size:
  113. break
  114. return np.exp(costs / iters)
  115. def train(data_dir, checkpoint_path, config):
  116. """Trains the model with the given data
  117. Args:
  118. data_dir: path to the data for the model (see data_utils for data
  119. format)
  120. checkpoint_path: the path to save the trained model checkpoints
  121. config: one of the above configs that specify the model and how it
  122. should be run and trained
  123. Returns:
  124. None
  125. """
  126. # Prepare Name data.
  127. print("Reading Name data in %s" % data_dir)
  128. names, counts = data_utils.read_names(data_dir)
  129. with tf.Graph().as_default(), tf.Session() as session:
  130. initializer = tf.random_uniform_initializer(-config.init_scale,
  131. config.init_scale)
  132. with tf.variable_scope("model", reuse=None, initializer=initializer):
  133. m = NamignizerModel(is_training=True, config=config)
  134. tf.global_variables_initializer().run()
  135. for i in range(config.max_max_epoch):
  136. lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
  137. m.assign_lr(session, config.learning_rate * lr_decay)
  138. print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
  139. train_perplexity = run_epoch(session, m, names, counts, config.epoch_size, m.train_op,
  140. verbose=True)
  141. print("Epoch: %d Train Perplexity: %.3f" %
  142. (i + 1, train_perplexity))
  143. m.saver.save(session, checkpoint_path, global_step=i)
  144. def namignize(names, checkpoint_path, config):
  145. """Recognizes names and prints the Perplexity of the model for each names
  146. in the list
  147. Args:
  148. names: a list of names in the model format
  149. checkpoint_path: the path to restore the trained model from, should not
  150. include the model name, just the path to
  151. config: one of the above configs that specify the model and how it
  152. should be run and trained
  153. Returns:
  154. None
  155. """
  156. with tf.Graph().as_default(), tf.Session() as session:
  157. with tf.variable_scope("model"):
  158. m = NamignizerModel(is_training=False, config=config)
  159. m.saver.restore(session, checkpoint_path)
  160. for name in names:
  161. x, y = data_utils.name_to_batch(name, m.batch_size, m.num_steps)
  162. cost, loss, _ = session.run([m.cost, m.loss, tf.no_op()],
  163. {m.input_data: x,
  164. m.targets: y,
  165. m.weights: np.concatenate((
  166. np.ones(len(name)), np.zeros(m.batch_size * m.num_steps - len(name))))})
  167. print("Name {} gives us a perplexity of {}".format(
  168. name, np.exp(cost)))
  169. def namignator(checkpoint_path, config):
  170. """Generates names randomly according to a given model
  171. Args:
  172. checkpoint_path: the path to restore the trained model from, should not
  173. include the model name, just the path to
  174. config: one of the above configs that specify the model and how it
  175. should be run and trained
  176. Returns:
  177. None
  178. """
  179. # mutate the config to become a name generator config
  180. config.num_steps = 1
  181. config.batch_size = 1
  182. with tf.Graph().as_default(), tf.Session() as session:
  183. with tf.variable_scope("model"):
  184. m = NamignizerModel(is_training=False, config=config)
  185. m.saver.restore(session, checkpoint_path)
  186. activations, final_state, _ = session.run([m.activations, m.final_state, tf.no_op()],
  187. {m.input_data: np.zeros((1, 1)),
  188. m.targets: np.zeros((1, 1)),
  189. m.weights: np.ones(1)})
  190. # sample from our softmax activations
  191. next_letter = np.random.choice(27, p=activations[0])
  192. name = [next_letter]
  193. while next_letter != 0:
  194. activations, final_state, _ = session.run([m.activations, m.final_state, tf.no_op()],
  195. {m.input_data: [[next_letter]],
  196. m.targets: np.zeros((1, 1)),
  197. m.initial_state: final_state,
  198. m.weights: np.ones(1)})
  199. next_letter = np.random.choice(27, p=activations[0])
  200. name += [next_letter]
  201. print(map(lambda x: chr(x + 96), name))
  202. if __name__ == "__main__":
  203. train("data/SmallNames.txt", "model/namignizer", SmallConfig)
  204. namignize(["mary", "ida", "gazorbazorb", "mmmhmm", "bob"],
  205. tf.train.latest_checkpoint("model"), SmallConfig)
  206. namignator(tf.train.latest_checkpoint("model"), SmallConfig)