model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Author: aneelakantan (Arvind Neelakantan)
"""
import numpy as np
import tensorflow as tf

import nn_utils


class Graph():
  def __init__(self, utility, batch_size, max_passes, mode="train"):
    self.utility = utility
    self.data_type = self.utility.tf_data_type[self.utility.FLAGS.data_type]
    self.max_elements = self.utility.FLAGS.max_elements
    max_elements = self.utility.FLAGS.max_elements
    self.num_cols = self.utility.FLAGS.max_number_cols
    self.num_word_cols = self.utility.FLAGS.max_word_cols
    self.question_length = self.utility.FLAGS.question_length
    self.batch_size = batch_size
    self.max_passes = max_passes
    self.mode = mode
    self.embedding_dims = self.utility.FLAGS.embedding_dims
    # input question and a mask
    self.batch_question = tf.placeholder(tf.int32,
                                         [batch_size, self.question_length])
    self.batch_question_attention_mask = tf.placeholder(
        self.data_type, [batch_size, self.question_length])
    # ground-truth scalar answer and lookup answer
    self.batch_answer = tf.placeholder(self.data_type, [batch_size])
    self.batch_print_answer = tf.placeholder(
        self.data_type,
        [batch_size, self.num_cols + self.num_word_cols, max_elements])
    # number columns and their processed versions
    self.batch_number_column = tf.placeholder(
        self.data_type, [batch_size, self.num_cols, max_elements
                        ])  # columns with numeric entries
    self.batch_processed_number_column = tf.placeholder(
        self.data_type, [batch_size, self.num_cols, max_elements])
    self.batch_processed_sorted_index_number_column = tf.placeholder(
        tf.int32, [batch_size, self.num_cols, max_elements])
    # word columns and their processed versions
    self.batch_processed_word_column = tf.placeholder(
        self.data_type, [batch_size, self.num_word_cols, max_elements])
    self.batch_processed_sorted_index_word_column = tf.placeholder(
        tf.int32, [batch_size, self.num_word_cols, max_elements])
    self.batch_word_column_entry_mask = tf.placeholder(
        tf.int32, [batch_size, self.num_word_cols, max_elements])
    # names of word and number columns along with their masks
    self.batch_word_column_names = tf.placeholder(
        tf.int32,
        [batch_size, self.num_word_cols, self.utility.FLAGS.max_entry_length])
    self.batch_word_column_mask = tf.placeholder(
        self.data_type, [batch_size, self.num_word_cols])
    self.batch_number_column_names = tf.placeholder(
        tf.int32,
        [batch_size, self.num_cols, self.utility.FLAGS.max_entry_length])
    self.batch_number_column_mask = tf.placeholder(self.data_type,
                                                   [batch_size, self.num_cols])
    # exact match and group-by-max operations
    self.batch_exact_match = tf.placeholder(
        self.data_type,
        [batch_size, self.num_cols + self.num_word_cols, max_elements])
    self.batch_column_exact_match = tf.placeholder(
        self.data_type, [batch_size, self.num_cols + self.num_word_cols])
    self.batch_group_by_max = tf.placeholder(
        self.data_type,
        [batch_size, self.num_cols + self.num_word_cols, max_elements])
    # numbers in the question along with their positions. These are used to
    # compute the arguments to the comparison operations.
    self.batch_question_number = tf.placeholder(self.data_type, [batch_size, 1])
    self.batch_question_number_one = tf.placeholder(self.data_type,
                                                    [batch_size, 1])
    self.batch_question_number_mask = tf.placeholder(
        self.data_type, [batch_size, max_elements])
    self.batch_question_number_one_mask = tf.placeholder(self.data_type,
                                                         [batch_size, 1])
    self.batch_ordinal_question = tf.placeholder(
        self.data_type, [batch_size, self.question_length])
    self.batch_ordinal_question_one = tf.placeholder(
        self.data_type, [batch_size, self.question_length])

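  # The question RNN below unrolls an LSTM over the fixed-length, padded
  # question and returns both the final hidden state and the per-timestep
  # hidden states (time-major: question_length x batch_size x embedding_dims).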
  def LSTM_question_embedding(self, sentence, sentence_length):
    # LSTM processes the input question
    lstm_params = "question_lstm"
    hidden_vectors = []
    sentence = self.batch_question
    question_hidden = tf.zeros(
        [self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
    question_c_hidden = tf.zeros(
        [self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
    if (self.utility.FLAGS.rnn_dropout > 0.0):
      if (self.mode == "train"):
        rnn_dropout_mask = tf.cast(
            tf.random_uniform(
                tf.shape(question_hidden), minval=0.0, maxval=1.0) <
            self.utility.FLAGS.rnn_dropout,
            self.data_type) / self.utility.FLAGS.rnn_dropout
      else:
        rnn_dropout_mask = tf.ones_like(question_hidden)
    for question_iterator in range(self.question_length):
      curr_word = sentence[:, question_iterator]
      question_vector = nn_utils.apply_dropout(
          nn_utils.get_embedding(curr_word, self.utility, self.params),
          self.utility.FLAGS.dropout, self.mode)
      question_hidden, question_c_hidden = nn_utils.LSTMCell(
          question_vector, question_hidden, question_c_hidden, lstm_params,
          self.params)
      if (self.utility.FLAGS.rnn_dropout > 0.0):
        question_hidden = question_hidden * rnn_dropout_mask
      hidden_vectors.append(tf.expand_dims(question_hidden, 0))
    hidden_vectors = tf.concat(axis=0, values=hidden_vectors)
    return question_hidden, hidden_vectors

  def history_recurrent_step(self, curr_hprev, hprev):
    # A single RNN step for the controller / history RNN
    return tf.tanh(
        tf.matmul(
            tf.concat(axis=1, values=[hprev, curr_hprev]), self.params[
                "history_recurrent"])) + self.params["history_recurrent_bias"]

  def question_number_softmax(self, hidden_vectors):
    # Attention over the question to decide which question number is passed to
    # the comparison ops

    def compute_ans(op_embedding, comparison):
      op_embedding = tf.expand_dims(op_embedding, 0)
      # dot product of the operation embedding with the hidden state to the
      # left of the number occurrence
      first = tf.transpose(
          tf.matmul(op_embedding,
                    tf.transpose(
                        tf.reduce_sum(hidden_vectors * tf.tile(
                            tf.expand_dims(
                                tf.transpose(self.batch_ordinal_question), 2),
                            [1, 1, self.utility.FLAGS.embedding_dims]), 0))))
      second = self.batch_question_number_one_mask + tf.transpose(
          tf.matmul(op_embedding,
                    tf.transpose(
                        tf.reduce_sum(hidden_vectors * tf.tile(
                            tf.expand_dims(
                                tf.transpose(self.batch_ordinal_question_one), 2
                            ), [1, 1, self.utility.FLAGS.embedding_dims]), 0))))
      question_number_softmax = tf.nn.softmax(
          tf.concat(axis=1, values=[first, second]))
      if (self.mode == "test"):
        cond = tf.equal(question_number_softmax,
                        tf.reshape(
                            tf.reduce_max(question_number_softmax, 1),
                            [self.batch_size, 1]))
        question_number_softmax = tf.where(
            cond,
            tf.fill(tf.shape(question_number_softmax), 1.0),
            tf.fill(tf.shape(question_number_softmax), 0.0))
        question_number_softmax = tf.cast(question_number_softmax,
                                          self.data_type)
      ans = tf.reshape(
          tf.reduce_sum(question_number_softmax * tf.concat(
              axis=1,
              values=[self.batch_question_number,
                      self.batch_question_number_one]), 1),
          [self.batch_size, 1])
      return ans

    def compute_op_position(op_name):
      for i in range(len(self.utility.operations_set)):
        if (op_name == self.utility.operations_set[i]):
          return i

    def compute_question_number(op_name):
      op_embedding = tf.nn.embedding_lookup(self.params_unit,
                                            compute_op_position(op_name))
      return compute_ans(op_embedding, op_name)

    curr_greater_question_number = compute_question_number("greater")
    curr_lesser_question_number = compute_question_number("lesser")
    curr_geq_question_number = compute_question_number("geq")
    curr_leq_question_number = compute_question_number("leq")
    return (curr_greater_question_number, curr_lesser_question_number,
            curr_geq_question_number, curr_leq_question_number)

  def perform_attention(self, context_vector, hidden_vectors, length, mask):
    # Performs attention on hidden_vectors using the context vector
    context_vector = tf.tile(
        tf.expand_dims(context_vector, 0), [length, 1, 1])  # time * bs * d
    attention_softmax = tf.nn.softmax(
        tf.transpose(tf.reduce_sum(context_vector * hidden_vectors, 2)) +
        mask)  # batch_size * time
    attention_softmax = tf.tile(
        tf.expand_dims(tf.transpose(attention_softmax), 2),
        [1, 1, self.embedding_dims])
    ans_vector = tf.reduce_sum(attention_softmax * hidden_vectors, 0)
    return ans_vector

  # computes embeddings for column names using parameters of the question module
  def get_column_hidden_vectors(self):
    # vector representations for the column names
    self.column_hidden_vectors = tf.reduce_sum(
        nn_utils.get_embedding(self.batch_number_column_names, self.utility,
                               self.params), 2)
    self.word_column_hidden_vectors = tf.reduce_sum(
        nn_utils.get_embedding(self.batch_word_column_names, self.utility,
                               self.params), 2)

  def create_summary_embeddings(self):
    # embeddings for each text entry in the table, using parameters of the
    # question module
    self.summary_text_entry_embeddings = tf.reduce_sum(
        tf.expand_dims(self.batch_exact_match, 3) * tf.expand_dims(
            tf.expand_dims(
                tf.expand_dims(
                    nn_utils.get_embedding(self.utility.entry_match_token_id,
                                           self.utility, self.params), 0), 1),
            2), 2)

  def compute_column_softmax(self, column_controller_vector, time_step):
    # compute a softmax over all the columns using the column controller vector
    column_controller_vector = tf.tile(
        tf.expand_dims(column_controller_vector, 1),
        [1, self.num_cols + self.num_word_cols, 1])  # max_cols * bs * d
    column_controller_vector = nn_utils.apply_dropout(
        column_controller_vector, self.utility.FLAGS.dropout, self.mode)
    self.full_column_hidden_vectors = tf.concat(
        axis=1,
        values=[self.column_hidden_vectors, self.word_column_hidden_vectors])
    self.full_column_hidden_vectors += self.summary_text_entry_embeddings
    self.full_column_hidden_vectors = nn_utils.apply_dropout(
        self.full_column_hidden_vectors, self.utility.FLAGS.dropout, self.mode)
    column_logits = tf.reduce_sum(
        column_controller_vector * self.full_column_hidden_vectors, 2) + (
            self.params["word_match_feature_column_name"] *
            self.batch_column_exact_match) + self.full_column_mask
    column_softmax = tf.nn.softmax(column_logits)  # batch_size * max_cols
    return column_softmax

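  # compute_first_or_last walks the rows in order (or reverse order) and gives
  # each row the probability mass that has not yet been assigned,
  # curr_prob = select[:, i] * (1 - running_sum), so the result concentrates
  # the row-selector probability on the first (or last) selected row.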
  def compute_first_or_last(self, select, first=True):
    # perform the first or last operation on the row selector with
    # probabilistic row selection
    answer = tf.zeros_like(select)
    running_sum = tf.zeros([self.batch_size, 1], self.data_type)
    for i in range(self.max_elements):
      if (first):
        current = tf.slice(select, [0, i], [self.batch_size, 1])
      else:
        current = tf.slice(select, [0, self.max_elements - 1 - i],
                           [self.batch_size, 1])
      curr_prob = current * (1 - running_sum)
      curr_prob = curr_prob * tf.cast(curr_prob >= 0.0, self.data_type)
      running_sum += curr_prob
      temp_ans = []
      curr_prob = tf.expand_dims(tf.reshape(curr_prob, [self.batch_size]), 0)
      for i_ans in range(self.max_elements):
        if (not (first) and i_ans == self.max_elements - 1 - i):
          temp_ans.append(curr_prob)
        elif (first and i_ans == i):
          temp_ans.append(curr_prob)
        else:
          temp_ans.append(tf.zeros_like(curr_prob))
      temp_ans = tf.transpose(tf.concat(axis=0, values=temp_ans))
      answer += temp_ans
    return answer

  def make_hard_softmax(self, softmax):
    # converts a soft selection to a hard selection; used at test time
    cond = tf.equal(
        softmax, tf.reshape(tf.reduce_max(softmax, 1), [self.batch_size, 1]))
    softmax = tf.where(
        cond, tf.fill(tf.shape(softmax), 1.0), tf.fill(tf.shape(softmax), 0.0))
    softmax = tf.cast(softmax, self.data_type)
    return softmax

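  # compute_max_or_min visits each column's entries in sorted order (via the
  # precomputed sorted-index tensors) and, as in compute_first_or_last, assigns
  # the remaining probability mass to the entry realizing the max (or min)
  # under the soft row selection, giving a soft argmax/argmin per column.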
  def compute_max_or_min(self, select, maxi=True):
    # computes the argmax and argmin of a column with probabilistic row
    # selection
    answer = tf.zeros([
        self.batch_size, self.num_cols + self.num_word_cols, self.max_elements
    ], self.data_type)
    sum_prob = tf.zeros([self.batch_size, self.num_cols + self.num_word_cols],
                        self.data_type)
    for j in range(self.max_elements):
      if (maxi):
        curr_pos = j
      else:
        curr_pos = self.max_elements - 1 - j
      select_index = tf.slice(self.full_processed_sorted_index_column,
                              [0, 0, curr_pos], [self.batch_size, -1, 1])
      select_mask = tf.equal(
          tf.tile(
              tf.expand_dims(
                  tf.tile(
                      tf.expand_dims(tf.range(self.max_elements), 0),
                      [self.batch_size, 1]), 1),
              [1, self.num_cols + self.num_word_cols, 1]), select_index)
      curr_prob = tf.expand_dims(select, 1) * tf.cast(
          select_mask, self.data_type) * self.select_bad_number_mask
      curr_prob = curr_prob * tf.expand_dims((1 - sum_prob), 2)
      curr_prob = curr_prob * tf.expand_dims(
          tf.cast((1 - sum_prob) > 0.0, self.data_type), 2)
      answer = tf.where(select_mask, curr_prob, answer)
      sum_prob += tf.reduce_sum(curr_prob, 2)
    return answer

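  # perform_operations applies all 15 operations in one shot and mixes their
  # results with the operation softmax: 1 scalar op (count), 13 row-selector
  # ops (prev, next, first, last, group_by_max, greater, lesser, geq, leq,
  # max, min, word match, reset) and 1 print (lookup) op, matching
  # length_content, length_select and length_print below.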
  def perform_operations(self, softmax, full_column_softmax, select,
                         prev_select_1, curr_pass):
    # performs all the 15 operations; computes the scalar output, lookup answer
    # and row selector
    column_softmax = tf.slice(full_column_softmax, [0, 0],
                              [self.batch_size, self.num_cols])
    word_column_softmax = tf.slice(full_column_softmax, [0, self.num_cols],
                                   [self.batch_size, self.num_word_cols])
    init_max = self.compute_max_or_min(select, maxi=True)
    init_min = self.compute_max_or_min(select, maxi=False)
    # operations that are column independent
    count = tf.reshape(tf.reduce_sum(select, 1), [self.batch_size, 1])
    select_full_column_softmax = tf.tile(
        tf.expand_dims(full_column_softmax, 2),
        [1, 1, self.max_elements
        ])  # BS * (max_cols + max_word_cols) * max_elements
    select_word_column_softmax = tf.tile(
        tf.expand_dims(word_column_softmax, 2),
        [1, 1, self.max_elements])  # BS * max_word_cols * max_elements
    select_greater = tf.reduce_sum(
        self.init_select_greater * select_full_column_softmax,
        1) * self.batch_question_number_mask  # BS * max_elements
    select_lesser = tf.reduce_sum(
        self.init_select_lesser * select_full_column_softmax,
        1) * self.batch_question_number_mask  # BS * max_elements
    select_geq = tf.reduce_sum(
        self.init_select_geq * select_full_column_softmax,
        1) * self.batch_question_number_mask  # BS * max_elements
    select_leq = tf.reduce_sum(
        self.init_select_leq * select_full_column_softmax,
        1) * self.batch_question_number_mask  # BS * max_elements
    select_max = tf.reduce_sum(init_max * select_full_column_softmax,
                               1)  # BS * max_elements
    select_min = tf.reduce_sum(init_min * select_full_column_softmax,
                               1)  # BS * max_elements
    select_prev = tf.concat(axis=1, values=[
        tf.slice(select, [0, 1], [self.batch_size, self.max_elements - 1]),
        tf.cast(tf.zeros([self.batch_size, 1]), self.data_type)
    ])
    select_next = tf.concat(axis=1, values=[
        tf.cast(tf.zeros([self.batch_size, 1]), self.data_type), tf.slice(
            select, [0, 0], [self.batch_size, self.max_elements - 1])
    ])
    select_last_rs = self.compute_first_or_last(select, False)
    select_first_rs = self.compute_first_or_last(select, True)
    select_word_match = tf.reduce_sum(self.batch_exact_match *
                                      select_full_column_softmax, 1)
    select_group_by_max = tf.reduce_sum(self.batch_group_by_max *
                                        select_full_column_softmax, 1)
    length_content = 1
    length_select = 13
    length_print = 1
    values = tf.concat(axis=1, values=[count])
    softmax_content = tf.slice(softmax, [0, 0],
                               [self.batch_size, length_content])
    # compute the scalar output
    output = tf.reduce_sum(tf.multiply(softmax_content, values), 1)
    # compute the lookup answer
    softmax_print = tf.slice(softmax, [0, length_content + length_select],
                             [self.batch_size, length_print])
    curr_print = select_full_column_softmax * tf.tile(
        tf.expand_dims(select, 1),
        [1, self.num_cols + self.num_word_cols, 1
        ])  # BS * max_cols * max_elements (considers only the column)
    self.batch_lookup_answer = curr_print * tf.tile(
        tf.expand_dims(softmax_print, 2),
        [1, self.num_cols + self.num_word_cols, self.max_elements
        ])  # BS * max_cols * max_elements
    self.batch_lookup_answer = self.batch_lookup_answer * self.select_full_mask
    # compute the row selector
    softmax_select = tf.slice(softmax, [0, length_content],
                              [self.batch_size, length_select])
    select_lists = [
        tf.expand_dims(select_prev, 1), tf.expand_dims(select_next, 1),
        tf.expand_dims(select_first_rs, 1), tf.expand_dims(select_last_rs, 1),
        tf.expand_dims(select_group_by_max, 1),
        tf.expand_dims(select_greater, 1), tf.expand_dims(select_lesser, 1),
        tf.expand_dims(select_geq, 1), tf.expand_dims(select_leq, 1),
        tf.expand_dims(select_max, 1), tf.expand_dims(select_min, 1),
        tf.expand_dims(select_word_match, 1),
        tf.expand_dims(self.reset_select, 1)
    ]
    select = tf.reduce_sum(
        tf.tile(tf.expand_dims(softmax_select, 2), [1, 1, self.max_elements]) *
        tf.concat(axis=1, values=select_lists), 1)
    select = select * self.select_whole_mask
    return output, select

  def one_pass(self, select, question_embedding, hidden_vectors, hprev,
               prev_select_1, curr_pass):
    # Performs one timestep, which involves selecting an operation and a column
    attention_vector = self.perform_attention(
        hprev, hidden_vectors, self.question_length,
        self.batch_question_attention_mask)  # batch_size * embedding_dims
    controller_vector = tf.nn.relu(
        tf.matmul(hprev, self.params["controller_prev"]) + tf.matmul(
            tf.concat(axis=1, values=[question_embedding, attention_vector]),
            self.params["controller"]))
    column_controller_vector = tf.nn.relu(
        tf.matmul(hprev, self.params["column_controller_prev"]) + tf.matmul(
            tf.concat(axis=1, values=[question_embedding, attention_vector]),
            self.params["column_controller"]))
    controller_vector = nn_utils.apply_dropout(
        controller_vector, self.utility.FLAGS.dropout, self.mode)
    self.operation_logits = tf.matmul(controller_vector,
                                      tf.transpose(self.params_unit))
    softmax = tf.nn.softmax(self.operation_logits)
    soft_softmax = softmax
    # compute the column softmax: bs * max_columns
    weighted_op_representation = tf.transpose(
        tf.matmul(tf.transpose(self.params_unit), tf.transpose(softmax)))
    column_controller_vector = tf.nn.relu(
        tf.matmul(
            tf.concat(axis=1, values=[
                column_controller_vector, weighted_op_representation
            ]), self.params["break_conditional"]))
    full_column_softmax = self.compute_column_softmax(column_controller_vector,
                                                      curr_pass)
    soft_column_softmax = full_column_softmax
    if (self.mode == "test"):
      full_column_softmax = self.make_hard_softmax(full_column_softmax)
      softmax = self.make_hard_softmax(softmax)
    output, select = self.perform_operations(softmax, full_column_softmax,
                                             select, prev_select_1, curr_pass)
    return (output, select, softmax, soft_softmax, full_column_softmax,
            soft_column_softmax)

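  # compute_lookup_error returns, per batch entry, the smallest print error
  # over all table cells whose gold print answer equals `val`; the final factor
  # zeroes it out for batch entries where `val` does not occur in the gold
  # answer at all.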
  def compute_lookup_error(self, val):
    # computes the lookup error
    cond = tf.equal(self.batch_print_answer, val)
    inter = tf.where(
        cond, self.init_print_error,
        tf.tile(
            tf.reshape(tf.constant(1e10, self.data_type), [1, 1, 1]), [
                self.batch_size, self.utility.FLAGS.max_word_cols +
                self.utility.FLAGS.max_number_cols,
                self.utility.FLAGS.max_elements
            ]))
    return tf.reduce_min(tf.reduce_min(inter, 1), 1) * tf.cast(
        tf.greater(
            tf.reduce_sum(tf.reduce_sum(tf.cast(cond, self.data_type), 1), 1),
            0.0), self.data_type)

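  # soft_min is a smooth minimum clipped at zero:
  #   soft_min(x, y) = max(-(1 / k) * log(exp(-k * x) + exp(-k * y)), 0)
  # where k = FLAGS.soft_min_value; larger k brings it closer to min(x, y).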
  def soft_min(self, x, y):
    return tf.maximum(-1.0 * (1 / (
        self.utility.FLAGS.soft_min_value + 0.0)) * tf.log(
            tf.exp(-self.utility.FLAGS.soft_min_value * x) + tf.exp(
                -self.utility.FLAGS.soft_min_value * y)), tf.zeros_like(x))

  def error_computation(self):
    # computes the error of each example in a batch
    math_error = 0.5 * tf.square(
        tf.subtract(self.scalar_output, self.batch_answer))
    # scale the math error
    math_error = math_error / self.rows
    math_error = tf.minimum(math_error, self.utility.FLAGS.max_math_error *
                            tf.ones(tf.shape(math_error), self.data_type))
    self.init_print_error = tf.where(
        self.batch_gold_select, -1 * tf.log(self.batch_lookup_answer + 1e-300 +
                                            self.invert_select_full_mask), -1 *
        tf.log(1 - self.batch_lookup_answer)) * self.select_full_mask
    print_error_1 = self.init_print_error * tf.cast(
        tf.equal(self.batch_print_answer, 0.0), self.data_type)
    print_error = tf.reduce_sum(tf.reduce_sum((print_error_1), 1), 1)
    for val in range(1, 58):
      print_error += self.compute_lookup_error(val + 0.0)
    print_error = print_error * self.utility.FLAGS.print_cost / self.num_entries
    if (self.mode == "train"):
      error = tf.where(
          tf.logical_and(
              tf.not_equal(self.batch_answer, 0.0),
              tf.not_equal(
                  tf.reduce_sum(tf.reduce_sum(self.batch_print_answer, 1), 1),
                  0.0)),
          self.soft_min(math_error, print_error),
          tf.where(
              tf.not_equal(self.batch_answer, 0.0), math_error, print_error))
    else:
      error = tf.where(
          tf.logical_and(
              tf.equal(self.scalar_output, 0.0),
              tf.equal(
                  tf.reduce_sum(tf.reduce_sum(self.batch_lookup_answer, 1), 1),
                  0.0)),
          tf.ones_like(math_error),
          tf.where(
              tf.equal(self.scalar_output, 0.0), print_error, math_error))
    return error

  def batch_process(self):
    # Computes the loss and the fraction of correct examples in a batch.
    self.params_unit = nn_utils.apply_dropout(
        self.params["unit"], self.utility.FLAGS.dropout, self.mode)
    batch_size = self.batch_size
    max_passes = self.max_passes
    num_timesteps = 1
    max_elements = self.max_elements
    select = tf.cast(
        tf.fill([self.batch_size, max_elements], 1.0), self.data_type)
    hprev = tf.cast(
        tf.fill([self.batch_size, self.embedding_dims], 0.0),
        self.data_type)  # running sum of the hidden states of the model
    output = tf.cast(tf.fill([self.batch_size, 1], 0.0),
                     self.data_type)  # output of the model
    correct = tf.cast(
        tf.fill([1], 0.0), self.data_type
    )  # to compute accuracy; returns the number of correct examples in this batch
    total_error = 0.0
    prev_select_1 = tf.zeros_like(select)
    self.create_summary_embeddings()
    self.get_column_hidden_vectors()
    # get the question embedding
    question_embedding, hidden_vectors = self.LSTM_question_embedding(
        self.batch_question, self.question_length)
    # compute the arguments for the comparison operations
    (greater_question_number, lesser_question_number, geq_question_number,
     leq_question_number) = self.question_number_softmax(hidden_vectors)
    self.init_select_greater = tf.cast(
        tf.greater(self.full_processed_column,
                   tf.expand_dims(greater_question_number, 2)),
        self.data_type) * self.select_bad_number_mask  # bs * max_cols * max_elements
    self.init_select_lesser = tf.cast(
        tf.less(self.full_processed_column,
                tf.expand_dims(lesser_question_number, 2)),
        self.data_type) * self.select_bad_number_mask  # bs * max_cols * max_elements
    self.init_select_geq = tf.cast(
        tf.greater_equal(self.full_processed_column,
                         tf.expand_dims(geq_question_number, 2)),
        self.data_type) * self.select_bad_number_mask  # bs * max_cols * max_elements
    self.init_select_leq = tf.cast(
        tf.less_equal(self.full_processed_column,
                      tf.expand_dims(leq_question_number, 2)),
        self.data_type) * self.select_bad_number_mask  # bs * max_cols * max_elements
    self.init_select_word_match = 0
    if (self.utility.FLAGS.rnn_dropout > 0.0):
      if (self.mode == "train"):
        history_rnn_dropout_mask = tf.cast(
            tf.random_uniform(
                tf.shape(hprev), minval=0.0, maxval=1.0) <
            self.utility.FLAGS.rnn_dropout,
            self.data_type) / self.utility.FLAGS.rnn_dropout
      else:
        history_rnn_dropout_mask = tf.ones_like(hprev)
    select = select * self.select_whole_mask
    self.batch_log_prob = tf.zeros([self.batch_size], dtype=self.data_type)
    # Perform max_passes passes; at each pass select an operation and a column
    for curr_pass in range(max_passes):
      print "step: ", curr_pass
      output, select, softmax, soft_softmax, column_softmax, soft_column_softmax = self.one_pass(
          select, question_embedding, hidden_vectors, hprev, prev_select_1,
          curr_pass)
      prev_select_1 = select
      # compute the input to the history RNN
      input_op = tf.transpose(
          tf.matmul(
              tf.transpose(self.params_unit), tf.transpose(
                  soft_softmax)))  # weighted average of the operation embeddings
      input_col = tf.reduce_sum(
          tf.expand_dims(soft_column_softmax, 2) *
          self.full_column_hidden_vectors, 1)
      history_input = tf.concat(axis=1, values=[input_op, input_col])
      history_input = nn_utils.apply_dropout(
          history_input, self.utility.FLAGS.dropout, self.mode)
      hprev = self.history_recurrent_step(history_input, hprev)
      if (self.utility.FLAGS.rnn_dropout > 0.0):
        hprev = hprev * history_rnn_dropout_mask
    self.scalar_output = output
    error = self.error_computation()
    cond = tf.less(error, 0.0001, name="cond")
    correct_add = tf.where(
        cond, tf.fill(tf.shape(cond), 1.0), tf.fill(tf.shape(cond), 0.0))
    correct = tf.reduce_sum(correct_add)
    error = error / batch_size
    total_error = tf.reduce_sum(error)
    total_correct = correct / batch_size
    return total_error, total_correct

  def compute_error(self):
    # Sets the mask variables and performs batch processing
    self.batch_gold_select = self.batch_print_answer > 0.0
    self.full_column_mask = tf.concat(
        axis=1,
        values=[self.batch_number_column_mask, self.batch_word_column_mask])
    self.full_processed_column = tf.concat(
        axis=1,
        values=[self.batch_processed_number_column,
                self.batch_processed_word_column])
    self.full_processed_sorted_index_column = tf.concat(axis=1, values=[
        self.batch_processed_sorted_index_number_column,
        self.batch_processed_sorted_index_word_column
    ])
    self.select_bad_number_mask = tf.cast(
        tf.logical_and(
            tf.not_equal(self.full_processed_column,
                         self.utility.FLAGS.pad_int),
            tf.not_equal(self.full_processed_column,
                         self.utility.FLAGS.bad_number_pre_process)),
        self.data_type)
    self.select_mask = tf.cast(
        tf.logical_not(
            tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int)),
        self.data_type)
    self.select_word_mask = tf.cast(
        tf.logical_not(
            tf.equal(self.batch_word_column_entry_mask,
                     self.utility.dummy_token_id)), self.data_type)
    self.select_full_mask = tf.concat(
        axis=1, values=[self.select_mask, self.select_word_mask])
    self.select_whole_mask = tf.maximum(
        tf.reshape(
            tf.slice(self.select_mask, [0, 0, 0],
                     [self.batch_size, 1, self.max_elements]),
            [self.batch_size, self.max_elements]),
        tf.reshape(
            tf.slice(self.select_word_mask, [0, 0, 0],
                     [self.batch_size, 1, self.max_elements]),
            [self.batch_size, self.max_elements]))
    self.invert_select_full_mask = tf.cast(
        tf.concat(axis=1, values=[
            tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int),
            tf.equal(self.batch_word_column_entry_mask,
                     self.utility.dummy_token_id)
        ]), self.data_type)
    self.batch_lookup_answer = tf.zeros(tf.shape(self.batch_gold_select))
    self.reset_select = self.select_whole_mask
    self.rows = tf.reduce_sum(self.select_whole_mask, 1)
    self.num_entries = tf.reshape(
        tf.reduce_sum(tf.reduce_sum(self.select_full_mask, 1), 1),
        [self.batch_size])
    self.final_error, self.final_correct = self.batch_process()
    return self.final_error

  def create_graph(self, params, global_step):
    # Creates the graph that computes the error, the gradients and the
    # parameter updates
    self.params = params
    batch_size = self.batch_size
    learning_rate = tf.cast(self.utility.FLAGS.learning_rate, self.data_type)
    self.total_cost = self.compute_error()
    optimize_params = self.params.values()
    optimize_names = self.params.keys()
    print "optimize params ", optimize_names
    if (self.utility.FLAGS.l2_regularizer > 0.0):
      reg_cost = 0.0
      for ind_param in self.params.keys():
        reg_cost += tf.nn.l2_loss(self.params[ind_param])
      self.total_cost += self.utility.FLAGS.l2_regularizer * reg_cost
    grads = tf.gradients(self.total_cost, optimize_params, name="gradients")
    grad_norm = 0.0
    for p, name in zip(grads, optimize_names):
      print "grads: ", p, name
      if isinstance(p, tf.IndexedSlices):
        grad_norm += tf.reduce_sum(p.values * p.values)
      elif p is not None:
        grad_norm += tf.reduce_sum(p * p)
    grad_norm = tf.sqrt(grad_norm)
    max_grad_norm = np.float32(self.utility.FLAGS.clip_gradients).astype(
        self.utility.np_data_type[self.utility.FLAGS.data_type])
    grad_scale = tf.minimum(
        tf.cast(1.0, self.data_type), max_grad_norm / grad_norm)
    clipped_grads = list()
    for p in grads:
      if isinstance(p, tf.IndexedSlices):
        tmp = p.values * grad_scale
        clipped_grads.append(tf.IndexedSlices(tmp, p.indices))
      elif p is not None:
        clipped_grads.append(p * grad_scale)
      else:
        clipped_grads.append(p)
    grads = clipped_grads
    self.global_step = global_step
    params_list = self.params.values()
    params_list.append(self.global_step)
    adam = tf.train.AdamOptimizer(
        learning_rate,
        epsilon=tf.cast(self.utility.FLAGS.eps, self.data_type),
        use_locking=True)
    self.step = adam.apply_gradients(zip(grads, optimize_params),
                                     global_step=self.global_step)
    self.init_op = tf.global_variables_initializer()
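
# Minimal usage sketch (assumptions, not part of this file): `utility` exposes
# the FLAGS and token-id fields used above, and `params` / `global_step` are
# built by the accompanying training script before the graph is constructed.
#
#   graph = Graph(utility, batch_size, max_passes, mode="train")
#   graph.create_graph(params, global_step)
#   with tf.Session() as sess:
#     sess.run(graph.init_op)
#     # feed the batch_* placeholders and run graph.step / graph.total_cost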