ptb_word_lm.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Example / benchmark for building a PTB LSTM model.
  16. Trains the model described in:
  17. (Zaremba, et. al.) Recurrent Neural Network Regularization
  18. http://arxiv.org/abs/1409.2329
  19. There are 3 supported model configurations:
  20. ===========================================
  21. | config | epochs | train | valid | test
  22. ===========================================
  23. | small | 13 | 37.99 | 121.39 | 115.91
  24. | medium | 39 | 48.45 | 86.16 | 82.07
  25. | large | 55 | 37.87 | 82.62 | 78.29
  26. The exact results may vary depending on the random initialization.
  27. The hyperparameters used in the model:
  28. - init_scale - the initial scale of the weights
  29. - learning_rate - the initial value of the learning rate
  30. - max_grad_norm - the maximum permissible norm of the gradient
  31. - num_layers - the number of LSTM layers
  32. - num_steps - the number of unrolled steps of LSTM
  33. - hidden_size - the number of LSTM units
  34. - max_epoch - the number of epochs trained with the initial learning rate
  35. - max_max_epoch - the total number of epochs for training
  36. - keep_prob - the probability of keeping weights in the dropout layer
  37. - lr_decay - the decay of the learning rate for each epoch after "max_epoch"
  38. - batch_size - the batch size
  39. The data required for this example is in the data/ dir of the
  40. PTB dataset from Tomas Mikolov's webpage:
  41. $ wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
  42. $ tar xvf simple-examples.tgz
  43. To run:
  44. $ python ptb_word_lm.py --data_path=simple-examples/data/
  45. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import time

import numpy as np
import tensorflow as tf

import reader

flags = tf.flags
logging = tf.logging

flags.DEFINE_string(
    "model", "small",
    "A type of model. Possible options are: small, medium, large.")
flags.DEFINE_string("data_path", None,
                    "Where the training/test data is stored.")
flags.DEFINE_string("save_path", None,
                    "Model output directory.")
flags.DEFINE_bool("use_fp16", False,
                  "Train using 16-bit floats instead of 32-bit floats.")

FLAGS = flags.FLAGS


def data_type():
  return tf.float16 if FLAGS.use_fp16 else tf.float32


class PTBInput(object):
  """The input data."""

  def __init__(self, config, data, name=None):
    self.batch_size = batch_size = config.batch_size
    self.num_steps = num_steps = config.num_steps
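    # Targets are the inputs shifted one word to the right, so the last
    # position of each batch row cannot serve as an input; hence the "- 1".
    # epoch_size is the number of num_steps-wide slices available per epoch.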
    self.epoch_size = ((len(data) // batch_size) - 1) // num_steps
    self.input_data, self.targets = reader.ptb_producer(
        data, batch_size, num_steps, name=name)


class PTBModel(object):
  """The PTB model."""

  def __init__(self, is_training, config, input_):
    self._input = input_

    batch_size = input_.batch_size
    num_steps = input_.num_steps
    size = config.hidden_size
    vocab_size = config.vocab_size

    # Slightly better results can be obtained with forget gate biases
    # initialized to 1 but the hyperparameters of the model would need to be
    # different than reported in the paper.
    def lstm_cell():
      # With the latest TensorFlow source code (as of Mar 27, 2017),
      # the BasicLSTMCell will need a reuse parameter which is unfortunately
      # not defined in TensorFlow 1.0. To maintain backwards compatibility,
      # we add an argument check here:
      if 'reuse' in inspect.getargspec(
          tf.contrib.rnn.BasicLSTMCell.__init__).args:
        return tf.contrib.rnn.BasicLSTMCell(
            size, forget_bias=0.0, state_is_tuple=True,
            reuse=tf.get_variable_scope().reuse)
      else:
        return tf.contrib.rnn.BasicLSTMCell(
            size, forget_bias=0.0, state_is_tuple=True)

    attn_cell = lstm_cell
    if is_training and config.keep_prob < 1:
      def attn_cell():
        return tf.contrib.rnn.DropoutWrapper(
            lstm_cell(), output_keep_prob=config.keep_prob)
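    # Stack num_layers LSTM cells into one multi-layer cell. When training
    # with keep_prob < 1, each layer's output passes through dropout before
    # feeding the next layer.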
    cell = tf.contrib.rnn.MultiRNNCell(
        [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)

    self._initial_state = cell.zero_state(batch_size, data_type())

    with tf.device("/cpu:0"):
      embedding = tf.get_variable(
          "embedding", [vocab_size, size], dtype=data_type())
      inputs = tf.nn.embedding_lookup(embedding, input_.input_data)

    if is_training and config.keep_prob < 1:
      inputs = tf.nn.dropout(inputs, config.keep_prob)

    # Simplified version of models/tutorials/rnn/rnn.py's rnn().
    # This builds an unrolled LSTM for tutorial purposes only.
    # In general, use the rnn() or state_saving_rnn() from rnn.py.
    #
    # The alternative version of the code below is:
    #
    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
    # outputs, state = tf.contrib.rnn.static_rnn(
    #     cell, inputs, initial_state=self._initial_state)
    outputs = []
    state = self._initial_state
    with tf.variable_scope("RNN"):
      for time_step in range(num_steps):
        if time_step > 0: tf.get_variable_scope().reuse_variables()
        (cell_output, state) = cell(inputs[:, time_step, :], state)
        outputs.append(cell_output)

    output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, size])
    softmax_w = tf.get_variable(
        "softmax_w", [size, vocab_size], dtype=data_type())
    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
    logits = tf.matmul(output, softmax_w) + softmax_b
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [logits],
        [tf.reshape(input_.targets, [-1])],
        [tf.ones([batch_size * num_steps], dtype=data_type())])
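    # sequence_loss_by_example returns the per-target cross-entropy, weighted
    # here by all-ones weights. Summing it and dividing by batch_size gives
    # the total cross-entropy per sequence; run_epoch later turns this into
    # perplexity via exp(costs / iters).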
    self._cost = cost = tf.reduce_sum(loss) / batch_size
    self._final_state = state

    if not is_training:
      return

    self._lr = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
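    # Clip the global gradient norm to max_grad_norm to keep the SGD updates
    # stable.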
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                      config.max_grad_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars),
        global_step=tf.contrib.framework.get_or_create_global_step())
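    # The learning rate is a non-trainable variable updated through a
    # placeholder, so assign_lr can change it between epochs without
    # rebuilding the graph.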
    self._new_lr = tf.placeholder(
        tf.float32, shape=[], name="new_learning_rate")
    self._lr_update = tf.assign(self._lr, self._new_lr)

  def assign_lr(self, session, lr_value):
    session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

  @property
  def input(self):
    return self._input

  @property
  def initial_state(self):
    return self._initial_state

  @property
  def cost(self):
    return self._cost

  @property
  def final_state(self):
    return self._final_state

  @property
  def lr(self):
    return self._lr

  @property
  def train_op(self):
    return self._train_op


class SmallConfig(object):
  """Small config."""
  init_scale = 0.1
  learning_rate = 1.0
  max_grad_norm = 5
  num_layers = 2
  num_steps = 20
  hidden_size = 200
  max_epoch = 4
  max_max_epoch = 13
  keep_prob = 1.0
  lr_decay = 0.5
  batch_size = 20
  vocab_size = 10000


class MediumConfig(object):
  """Medium config."""
  init_scale = 0.05
  learning_rate = 1.0
  max_grad_norm = 5
  num_layers = 2
  num_steps = 35
  hidden_size = 650
  max_epoch = 6
  max_max_epoch = 39
  keep_prob = 0.5
  lr_decay = 0.8
  batch_size = 20
  vocab_size = 10000


class LargeConfig(object):
  """Large config."""
  init_scale = 0.04
  learning_rate = 1.0
  max_grad_norm = 10
  num_layers = 2
  num_steps = 35
  hidden_size = 1500
  max_epoch = 14
  max_max_epoch = 55
  keep_prob = 0.35
  lr_decay = 1 / 1.15
  batch_size = 20
  vocab_size = 10000


class TestConfig(object):
  """Tiny config, for testing."""
  init_scale = 0.1
  learning_rate = 1.0
  max_grad_norm = 1
  num_layers = 1
  num_steps = 2
  hidden_size = 2
  max_epoch = 1
  max_max_epoch = 1
  keep_prob = 1.0
  lr_decay = 0.5
  batch_size = 20
  vocab_size = 10000


def run_epoch(session, model, eval_op=None, verbose=False):
  """Runs the model on the given data."""
  start_time = time.time()
  costs = 0.0
  iters = 0
  state = session.run(model.initial_state)

  fetches = {
      "cost": model.cost,
      "final_state": model.final_state,
  }
  if eval_op is not None:
    fetches["eval_op"] = eval_op

  for step in range(model.input.epoch_size):
    feed_dict = {}
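    # Feed the final LSTM state of the previous minibatch as the initial
    # state of this one, so state is carried across the whole epoch.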
    for i, (c, h) in enumerate(model.initial_state):
      feed_dict[c] = state[i].c
      feed_dict[h] = state[i].h

    vals = session.run(fetches, feed_dict)
    cost = vals["cost"]
    state = vals["final_state"]

    costs += cost
    iters += model.input.num_steps

    if verbose and step % (model.input.epoch_size // 10) == 10:
      print("%.3f perplexity: %.3f speed: %.0f wps" %
            (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
             iters * model.input.batch_size / (time.time() - start_time)))
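  # Perplexity is the exponential of the average per-word cross-entropy.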
  return np.exp(costs / iters)


def get_config():
  if FLAGS.model == "small":
    return SmallConfig()
  elif FLAGS.model == "medium":
    return MediumConfig()
  elif FLAGS.model == "large":
    return LargeConfig()
  elif FLAGS.model == "test":
    return TestConfig()
  else:
    raise ValueError("Invalid model: %s" % FLAGS.model)


def main(_):
  if not FLAGS.data_path:
    raise ValueError("Must set --data_path to PTB data directory")

  raw_data = reader.ptb_raw_data(FLAGS.data_path)
  train_data, valid_data, test_data, _ = raw_data

  config = get_config()
  eval_config = get_config()
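  # Evaluate with batch_size = num_steps = 1 so the test data is consumed as
  # one continuous sequence, carrying state from word to word.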
  eval_config.batch_size = 1
  eval_config.num_steps = 1

  with tf.Graph().as_default():
    initializer = tf.random_uniform_initializer(-config.init_scale,
                                                config.init_scale)

    with tf.name_scope("Train"):
      train_input = PTBInput(config=config, data=train_data, name="TrainInput")
      with tf.variable_scope("Model", reuse=None, initializer=initializer):
        m = PTBModel(is_training=True, config=config, input_=train_input)
      tf.summary.scalar("Training Loss", m.cost)
      tf.summary.scalar("Learning Rate", m.lr)

    with tf.name_scope("Valid"):
      valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
      with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
      tf.summary.scalar("Validation Loss", mvalid.cost)

    with tf.name_scope("Test"):
      test_input = PTBInput(config=eval_config, data=test_data,
                            name="TestInput")
      with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mtest = PTBModel(is_training=False, config=eval_config,
                         input_=test_input)

    sv = tf.train.Supervisor(logdir=FLAGS.save_path)
    with sv.managed_session() as session:
      for i in range(config.max_max_epoch):
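        # Keep the initial learning rate for the first max_epoch epochs,
        # then decay it by lr_decay for each epoch after that.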
        lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
        m.assign_lr(session, config.learning_rate * lr_decay)

        print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
        train_perplexity = run_epoch(session, m, eval_op=m.train_op,
                                     verbose=True)
        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
        valid_perplexity = run_epoch(session, mvalid)
        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))

      test_perplexity = run_epoch(session, mtest)
      print("Test Perplexity: %.3f" % test_perplexity)

      if FLAGS.save_path:
        print("Saving model to %s." % FLAGS.save_path)
        sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)


if __name__ == "__main__":
  tf.app.run()