neural_programmer.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge
  16. This file calls functions to load & pre-process data, construct the TF graph
  17. and performs training or evaluation as specified by the flag evaluator_job
  18. Author: aneelakantan (Arvind Neelakantan)
  19. """
from __future__ import print_function

import time
from random import Random

import numpy as np
import tensorflow as tf

import model
import wiki_data
import parameters
import data_utils

tf.flags.DEFINE_integer("train_steps", 100001, "Number of steps to train")
tf.flags.DEFINE_integer("eval_cycle", 500,
                        "Evaluate model at every eval_cycle steps")
tf.flags.DEFINE_integer("max_elements", 100,
                        "maximum rows that are considered for processing")
tf.flags.DEFINE_integer(
    "max_number_cols", 15,
    "maximum number of number columns that are considered for processing")
tf.flags.DEFINE_integer(
    "max_word_cols", 25,
    "maximum number of word columns that are considered for processing")
tf.flags.DEFINE_integer("question_length", 62, "maximum question length")
tf.flags.DEFINE_integer("max_entry_length", 1, "")
tf.flags.DEFINE_integer("max_passes", 4, "number of operation passes")
tf.flags.DEFINE_integer("embedding_dims", 256, "")
tf.flags.DEFINE_integer("batch_size", 20, "")
tf.flags.DEFINE_float("clip_gradients", 1.0, "")
tf.flags.DEFINE_float("eps", 1e-6, "")
tf.flags.DEFINE_float("param_init", 0.1, "")
tf.flags.DEFINE_float("learning_rate", 0.001, "")
tf.flags.DEFINE_float("l2_regularizer", 0.0001, "")
tf.flags.DEFINE_float("print_cost", 50.0,
                      "weighting factor in the objective function")
tf.flags.DEFINE_string("job_id", "temp", "job id")
tf.flags.DEFINE_string("output_dir", "../model/", "output_dir")
tf.flags.DEFINE_string("data_dir", "../data/", "data_dir")
tf.flags.DEFINE_integer("write_every", 500,
                        "write a checkpoint every N steps")
tf.flags.DEFINE_integer("param_seed", 150, "")
tf.flags.DEFINE_integer("python_seed", 200, "")
tf.flags.DEFINE_float("dropout", 0.8, "dropout keep probability")
tf.flags.DEFINE_float("rnn_dropout", 0.9,
                      "dropout keep probability for rnn connections")
tf.flags.DEFINE_float("pad_int", -20000.0,
                      "number columns are padded with pad_int")
tf.flags.DEFINE_string("data_type", "double", "float or double")
tf.flags.DEFINE_float("word_dropout_prob", 0.9, "word dropout keep prob")
tf.flags.DEFINE_integer("word_cutoff", 10, "")
tf.flags.DEFINE_integer("vocab_size", 10800, "")
tf.flags.DEFINE_boolean("evaluator_job", False,
                        "whether to run as trainer or evaluator")
tf.flags.DEFINE_float(
    "bad_number_pre_process", -200000.0,
    "number that is added to a corrupted table entry in a number column")
tf.flags.DEFINE_float("max_math_error", 3.0,
                      "max square loss error that is considered")
tf.flags.DEFINE_float("soft_min_value", 5.0, "")

FLAGS = tf.flags.FLAGS

class Utility:
  """Holds FLAGS and other variables that are used in different files."""

  def __init__(self):
    global FLAGS
    self.FLAGS = FLAGS
    self.unk_token = "UNK"
    self.entry_match_token = "entry_match"
    self.column_match_token = "column_match"
    self.dummy_token = "dummy_token"
    self.tf_data_type = {}
    self.tf_data_type["double"] = tf.float64
    self.tf_data_type["float"] = tf.float32
    self.np_data_type = {}
    self.np_data_type["double"] = np.float64
    self.np_data_type["float"] = np.float32
    self.operations_set = ["count"] + [
        "prev", "next", "first_rs", "last_rs", "group_by_max", "greater",
        "lesser", "geq", "leq", "max", "min", "word-match"
    ] + ["reset_select"] + ["print"]
    self.word_ids = {}
    self.reverse_word_ids = {}
    self.word_count = {}
    self.random = Random(FLAGS.python_seed)
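
# Note: evaluate() below averages accuracy over full batches only; a trailing
# partial batch of len(data) % batch_size examples is skipped.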

def evaluate(sess, data, batch_size, graph, i):
  """Computes accuracy on the dev set after training step i."""
  num_examples = 0.0
  gc = 0.0
  for j in range(0, len(data) - batch_size + 1, batch_size):
    [ct] = sess.run(
        [graph.final_correct],
        feed_dict=data_utils.generate_feed_dict(data, j, batch_size, graph))
    gc += ct * batch_size
    num_examples += batch_size
  print("dev set accuracy after ", i, " : ", gc / num_examples)
  print(num_examples, len(data))
  print("--------")

def Train(graph, utility, batch_size, train_data, sess, model_dir, saver):
  """Performs training, checkpointing every FLAGS.write_every steps."""
  curr = 0
  train_set_loss = 0.0
  utility.random.shuffle(train_data)
  start = time.time()
  for i in range(utility.FLAGS.train_steps):
    curr_step = i
    if i > 0 and i % FLAGS.write_every == 0:
      model_file = model_dir + "/model_" + str(i)
      saver.save(sess, model_file)
    if curr + batch_size >= len(train_data):
      # Epoch boundary: wrap around and reshuffle.
      curr = 0
      utility.random.shuffle(train_data)
    step, cost_value = sess.run(
        [graph.step, graph.total_cost],
        feed_dict=data_utils.generate_feed_dict(
            train_data, curr, batch_size, graph, train=True, utility=utility))
    curr = curr + batch_size
    train_set_loss += cost_value
    if i > 0 and i % FLAGS.eval_cycle == 0:
      end = time.time()
      time_taken = end - start
      print("step ", i, " ", time_taken, " seconds ")
      start = end
      print("train set loss: ", train_set_loss / utility.FLAGS.eval_cycle)
      train_set_loss = 0.0
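
# In evaluator mode, master() polls model_dir for checkpoints written by the
# trainer, restores them in increasing step order, and scores each one on the
# dev set. The newest checkpoint is skipped, presumably because the trainer
# may still be writing it.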

def master(train_data, dev_data, utility):
  """Creates the TF graph and calls the trainer or the evaluator."""
  batch_size = utility.FLAGS.batch_size
  model_dir = utility.FLAGS.output_dir + "/model" + utility.FLAGS.job_id + "/"
  # Create all parameters of the model.
  param_class = parameters.Parameters(utility)
  params, global_step, init = param_class.parameters(utility)
  key = "test" if FLAGS.evaluator_job else "train"
  graph = model.Graph(utility, batch_size, utility.FLAGS.max_passes, mode=key)
  graph.create_graph(params, global_step)
  prev_dev_error = 0.0
  final_loss = 0.0
  final_accuracy = 0.0
  # Start the session.
  with tf.Session() as sess:
    sess.run(init.name)
    sess.run(graph.init_op.name)
    to_save = params.copy()
    saver = tf.train.Saver(to_save, max_to_keep=500)
    if FLAGS.evaluator_job:
      while True:
        # Collect checkpoint files, keyed by their training step.
        selected_models = {}
        file_list = tf.gfile.ListDirectory(model_dir)
        for model_file in file_list:
          if ("checkpoint" in model_file or "index" in model_file or
              "meta" in model_file):
            continue
          if "data" in model_file:
            model_file = model_file.split(".")[0]
          model_step = int(model_file.split("_")[-1])
          selected_models[model_step] = model_file
        file_list = sorted(selected_models.items(), key=lambda x: x[0])
        if len(file_list) > 0:
          # Skip the newest checkpoint.
          file_list = file_list[:-1]
        print("list of models: ", file_list)
        for model_file in file_list:
          model_file = model_file[1]
          print("restoring: ", model_file)
          saver.restore(sess, model_dir + "/" + model_file)
          model_step = int(model_file.split("_")[-1])
          print("evaluating on dev ", model_file, model_step)
          evaluate(sess, dev_data, batch_size, graph, model_step)
    else:
      ckpt = tf.train.get_checkpoint_state(model_dir)
      print("model dir: ", model_dir)
      if not tf.gfile.IsDirectory(utility.FLAGS.output_dir):
        print("create dir: ", utility.FLAGS.output_dir)
        tf.gfile.MkDir(utility.FLAGS.output_dir)
      if not tf.gfile.IsDirectory(model_dir):
        print("create dir: ", model_dir)
        tf.gfile.MkDir(model_dir)
      Train(graph, utility, batch_size, train_data, sess, model_dir, saver)
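
# The file names in main() below correspond to the standard WikiTableQuestions
# splits (random-split-1 for train/dev, pristine-unseen-tables for test),
# pre-processed into .examples files under FLAGS.data_dir.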

def main(args):
  utility = Utility()
  train_name = "random-split-1-train.examples"
  dev_name = "random-split-1-dev.examples"
  test_name = "pristine-unseen-tables.examples"
  # Load the data.
  dat = wiki_data.WikiQuestionGenerator(train_name, dev_name, test_name,
                                        FLAGS.data_dir)
  train_data, dev_data, test_data = dat.load()
  utility.words = []
  utility.word_ids = {}
  utility.reverse_word_ids = {}
  # Construct the vocabulary.
  data_utils.construct_vocab(train_data, utility)
  data_utils.construct_vocab(dev_data, utility, True)
  data_utils.construct_vocab(test_data, utility, True)
  data_utils.add_special_words(utility)
  data_utils.perform_word_cutoff(utility)
  # Convert the data to int format and pad the inputs.
  train_data = data_utils.complete_wiki_processing(train_data, utility, True)
  dev_data = data_utils.complete_wiki_processing(dev_data, utility, False)
  test_data = data_utils.complete_wiki_processing(test_data, utility, False)
  print("# train examples ", len(train_data))
  print("# dev examples ", len(dev_data))
  print("# test examples ", len(test_data))
  print("running open source")
  # Construct the TF graph and train or evaluate.
  master(train_data, dev_data, utility)


if __name__ == "__main__":
  tf.app.run()