data_utils.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Functions for constructing the vocabulary, converting the examples to
integer format, and building the required masks for batch computation.

Author: aneelakantan (Arvind Neelakantan)
"""
  17. import copy
  18. import numbers
  19. import numpy as np
  20. import wiki_data
  21. def return_index(a):
  22. for i in range(len(a)):
  23. if (a[i] == 1.0):
  24. return i
  25. def construct_vocab(data, utility, add_word=False):
  26. ans = []
  27. for example in data:
  28. sent = ""
  29. for word in example.question:
  30. if (not (isinstance(word, numbers.Number))):
  31. sent += word + " "
  32. example.original_nc = copy.deepcopy(example.number_columns)
  33. example.original_wc = copy.deepcopy(example.word_columns)
  34. example.original_nc_names = copy.deepcopy(example.number_column_names)
  35. example.original_wc_names = copy.deepcopy(example.word_column_names)
  36. if (add_word):
  37. continue
  38. number_found = 0
  39. if (not (example.is_bad_example)):
  40. for word in example.question:
  41. if (isinstance(word, numbers.Number)):
  42. number_found += 1
  43. else:
  44. if (not (utility.word_ids.has_key(word))):
  45. utility.words.append(word)
  46. utility.word_count[word] = 1
  47. utility.word_ids[word] = len(utility.word_ids)
  48. utility.reverse_word_ids[utility.word_ids[word]] = word
  49. else:
  50. utility.word_count[word] += 1
  51. for col_name in example.word_column_names:
  52. for word in col_name:
  53. if (isinstance(word, numbers.Number)):
  54. number_found += 1
  55. else:
  56. if (not (utility.word_ids.has_key(word))):
  57. utility.words.append(word)
  58. utility.word_count[word] = 1
  59. utility.word_ids[word] = len(utility.word_ids)
  60. utility.reverse_word_ids[utility.word_ids[word]] = word
  61. else:
  62. utility.word_count[word] += 1
  63. for col_name in example.number_column_names:
  64. for word in col_name:
  65. if (isinstance(word, numbers.Number)):
  66. number_found += 1
  67. else:
  68. if (not (utility.word_ids.has_key(word))):
  69. utility.words.append(word)
  70. utility.word_count[word] = 1
  71. utility.word_ids[word] = len(utility.word_ids)
  72. utility.reverse_word_ids[utility.word_ids[word]] = word
  73. else:
  74. utility.word_count[word] += 1
  75. def word_lookup(word, utility):
  76. if (utility.word_ids.has_key(word)):
  77. return word
  78. else:
  79. return utility.unk_token
  80. def convert_to_int_2d_and_pad(a, utility):
  81. ans = []
  82. #print a
  83. for b in a:
  84. temp = []
  85. if (len(b) > utility.FLAGS.max_entry_length):
  86. b = b[0:utility.FLAGS.max_entry_length]
  87. for remaining in range(len(b), utility.FLAGS.max_entry_length):
  88. b.append(utility.dummy_token)
  89. assert len(b) == utility.FLAGS.max_entry_length
  90. for word in b:
  91. temp.append(utility.word_ids[word_lookup(word, utility)])
  92. ans.append(temp)
  93. #print ans
  94. return ans
  95. def convert_to_bool_and_pad(a, utility):
  96. a = a.tolist()
  97. for i in range(len(a)):
  98. for j in range(len(a[i])):
  99. if (a[i][j] < 1):
  100. a[i][j] = False
  101. else:
  102. a[i][j] = True
  103. a[i] = a[i] + [False] * (utility.FLAGS.max_elements - len(a[i]))
  104. return a
# Module-level record of table keys: complete_wiki_processing stores
# seen_tables[example.table_key] = 1 for each non-bad example it processes.
seen_tables = {}
  106. def partial_match(question, table, number):
  107. answer = []
  108. match = {}
  109. for i in range(len(table)):
  110. temp = []
  111. for j in range(len(table[i])):
  112. temp.append(0)
  113. answer.append(temp)
  114. for i in range(len(table)):
  115. for j in range(len(table[i])):
  116. for word in question:
  117. if (number):
  118. if (word == table[i][j]):
  119. answer[i][j] = 1.0
  120. match[i] = 1.0
  121. else:
  122. if (word in table[i][j]):
  123. answer[i][j] = 1.0
  124. match[i] = 1.0
  125. return answer, match
  126. def exact_match(question, table, number):
  127. #performs exact match operation
  128. answer = []
  129. match = {}
  130. matched_indices = []
  131. for i in range(len(table)):
  132. temp = []
  133. for j in range(len(table[i])):
  134. temp.append(0)
  135. answer.append(temp)
  136. for i in range(len(table)):
  137. for j in range(len(table[i])):
  138. if (number):
  139. for word in question:
  140. if (word == table[i][j]):
  141. match[i] = 1.0
  142. answer[i][j] = 1.0
  143. else:
  144. table_entry = table[i][j]
  145. for k in range(len(question)):
  146. if (k + len(table_entry) <= len(question)):
  147. if (table_entry == question[k:(k + len(table_entry))]):
  148. #if(len(table_entry) == 1):
  149. #print "match: ", table_entry, question
  150. match[i] = 1.0
  151. answer[i][j] = 1.0
  152. matched_indices.append((k, len(table_entry)))
  153. return answer, match, matched_indices
  154. def partial_column_match(question, table, number):
  155. answer = []
  156. for i in range(len(table)):
  157. answer.append(0)
  158. for i in range(len(table)):
  159. for word in question:
  160. if (word in table[i]):
  161. answer[i] = 1.0
  162. return answer
  163. def exact_column_match(question, table, number):
  164. #performs exact match on column names
  165. answer = []
  166. matched_indices = []
  167. for i in range(len(table)):
  168. answer.append(0)
  169. for i in range(len(table)):
  170. table_entry = table[i]
  171. for k in range(len(question)):
  172. if (k + len(table_entry) <= len(question)):
  173. if (table_entry == question[k:(k + len(table_entry))]):
  174. answer[i] = 1.0
  175. matched_indices.append((k, len(table_entry)))
  176. return answer, matched_indices
  177. def get_max_entry(a):
  178. e = {}
  179. for w in a:
  180. if (w != "UNK, "):
  181. if (e.has_key(w)):
  182. e[w] += 1
  183. else:
  184. e[w] = 1
  185. if (len(e) > 0):
  186. (key, val) = sorted(e.items(), key=lambda x: -1 * x[1])[0]
  187. if (val > 1):
  188. return key
  189. else:
  190. return -1.0
  191. else:
  192. return -1.0
  193. def list_join(a):
  194. ans = ""
  195. for w in a:
  196. ans += str(w) + ", "
  197. return ans
  198. def group_by_max(table, number):
  199. #computes the most frequently occuring entry in a column
  200. answer = []
  201. for i in range(len(table)):
  202. temp = []
  203. for j in range(len(table[i])):
  204. temp.append(0)
  205. answer.append(temp)
  206. for i in range(len(table)):
  207. if (number):
  208. curr = table[i]
  209. else:
  210. curr = [list_join(w) for w in table[i]]
  211. max_entry = get_max_entry(curr)
  212. #print i, max_entry
  213. for j in range(len(curr)):
  214. if (max_entry == curr[j]):
  215. answer[i][j] = 1.0
  216. else:
  217. answer[i][j] = 0.0
  218. return answer
  219. def pick_one(a):
  220. for i in range(len(a)):
  221. if (1.0 in a[i]):
  222. return True
  223. return False
  224. def check_processed_cols(col, utility):
  225. return True in [
  226. True for y in col
  227. if (y != utility.FLAGS.pad_int and y !=
  228. utility.FLAGS.bad_number_pre_process)
  229. ]
def complete_wiki_processing(data, utility, train=True):
  """Convert examples to the padded, integer-encoded form used for batching.

  For every example not flagged `is_bad_example` this function, in order:
    * computes entry-match, group-by-max and column-name-match features
      between the question and the `original_*` table snapshots;
    * appends the entry/column match tokens to the question when something
      matched, and replaces exactly matched word spans with the unk token;
    * pulls numbers and dates out of the question (the first two are kept as
      `question_number` and `question_number_1`), maps the remaining words
      to ids and left-pads the question and its masks to
      `FLAGS.question_length`;
    * pads every number/word column, its features and its names to
      `FLAGS.max_elements` rows and `FLAGS.max_number_cols` /
      `FLAGS.max_word_cols` columns;
    * builds the answer, the printable answer and the question-number masks.

  Examples are modified in place.

  Args:
    data: iterable of example objects. The `original_wc`, `original_nc`,
      `original_wc_names` and `original_nc_names` attributes read here are
      the ones `construct_vocab` sets.
    utility: object exposing `FLAGS`, `word_ids` and the special tokens.
    train: unused.

  Returns:
    The list of processed examples. Bad examples are skipped, and examples
    whose `len_col` exceeds `FLAGS.max_elements` are left out of the list
    (although they have already been modified by then).
  """
  #convert to integers and padding
  processed_data = []
  # Counted but not used or returned anywhere in this function.
  num_bad_examples = 0
  for example in data:
    number_found = 0
    if (example.is_bad_example):
      num_bad_examples += 1
    if (not (example.is_bad_example)):
      example.string_question = example.question[:]
      #entry match
      # Shallow copies, so the appends below do not grow the original lists.
      example.processed_number_columns = example.processed_number_columns[:]
      example.processed_word_columns = example.processed_word_columns[:]
      example.word_exact_match, word_match, matched_indices = exact_match(
          example.string_question, example.original_wc, number=False)
      example.number_exact_match, number_match, _ = exact_match(
          example.string_question, example.original_nc, number=True)
      # Fall back to partial word matching only when nothing matched exactly.
      if (not (pick_one(example.word_exact_match)) and not (
          pick_one(example.number_exact_match))):
        assert len(word_match) == 0
        assert len(number_match) == 0
        example.word_exact_match, word_match = partial_match(
            example.string_question, example.original_wc, number=False)
      #group by max
      example.word_group_by_max = group_by_max(example.original_wc, False)
      example.number_group_by_max = group_by_max(example.original_nc, True)
      #column name match
      # The matched index lists returned here are not used afterwards.
      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
          example.string_question, example.original_wc_names, number=False)
      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
          example.string_question, example.original_nc_names, number=False)
      # Same fallback for column names: partial match if no exact match.
      if (not (1.0 in example.word_column_exact_match) and not (
          1.0 in example.number_column_exact_match)):
        example.word_column_exact_match = partial_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match = partial_column_match(
            example.string_question, example.original_nc_names, number=False)
      # Signal the kind of match found by appending marker tokens to the
      # question itself.
      if (len(word_match) > 0 or len(number_match) > 0):
        example.question.append(utility.entry_match_token)
      if (1.0 in example.word_column_exact_match or
          1.0 in example.number_column_exact_match):
        example.question.append(utility.column_match_token)
      example.string_question = example.question[:]
      example.number_lookup_matrix = np.transpose(
          example.number_lookup_matrix)[:]
      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
      # Shallow copies: the outer lists are new, the inner lists are shared.
      example.columns = example.number_columns[:]
      example.word_columns = example.word_columns[:]
      example.len_total_cols = len(example.word_column_names) + len(
          example.number_column_names)
      example.column_names = example.number_column_names[:]
      example.word_column_names = example.word_column_names[:]
      example.string_column_names = example.number_column_names[:]
      example.string_word_column_names = example.word_column_names[:]
      example.sorted_number_index = []
      example.sorted_word_index = []
      example.column_mask = []
      example.word_column_mask = []
      example.processed_column_mask = []
      example.processed_word_column_mask = []
      example.word_column_entry_mask = []
      example.question_attention_mask = []
      # -1 means "no number found"; tested again when the masks are built.
      example.question_number = example.question_number_1 = -1
      example.question_attention_mask = []
      example.ordinal_question = []
      example.ordinal_question_one = []
      new_question = []
      # Number of rows, taken from the first column of either kind.
      if (len(example.number_columns) > 0):
        example.len_col = len(example.number_columns[0])
      else:
        example.len_col = len(example.word_columns[0])
      # Replace question spans that exactly matched a word entry with the
      # unk token. The indices refer to the question before the match tokens
      # were appended at its end, so they are still valid.
      for (start, length) in matched_indices:
        for j in range(length):
          example.question[start + j] = utility.unk_token
      # Separate numbers/dates from words. The first two numbers are stored;
      # any further ones are dropped from the question without being stored.
      for word in example.question:
        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
          if (not (isinstance(word, numbers.Number)) and
              wiki_data.is_date(word)):
            # Dates are strings here: strip the "X" and "-" characters.
            word = word.replace("X", "").replace("-", "")
          number_found += 1
          if (number_found == 1):
            example.question_number = word
            # Flag the position of the word seen just before the number, or
            # position 0 when the number comes first.
            if (len(example.ordinal_question) > 0):
              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
            else:
              example.ordinal_question.append(1.0)
          elif (number_found == 2):
            example.question_number_1 = word
            if (len(example.ordinal_question_one) > 0):
              example.ordinal_question_one[len(example.ordinal_question_one) -
                                           1] = 1.0
            else:
              example.ordinal_question_one.append(1.0)
        else:
          new_question.append(word)
          example.ordinal_question.append(0.0)
          example.ordinal_question_one.append(0.0)
      # Map the remaining words to ids; unknown words use the unk token id.
      example.question = [
          utility.word_ids[word_lookup(w, utility)] for w in new_question
      ]
      example.question_attention_mask = [0.0] * len(example.question)
      #when the first question number occurs before a word
      # In that case the ordinal lists hold one more element than the
      # question; the slices drop the excess.
      example.ordinal_question = example.ordinal_question[0:len(
          example.question)]
      example.ordinal_question_one = example.ordinal_question_one[0:len(
          example.question)]
      #question-padding
      # Padding is added on the left. No truncation is done for questions
      # longer than FLAGS.question_length.
      example.question = [utility.word_ids[utility.dummy_token]] * (
          utility.FLAGS.question_length - len(example.question)
      ) + example.question
      example.question_attention_mask = [-10000.0] * (
          utility.FLAGS.question_length - len(example.question_attention_mask)
      ) + example.question_attention_mask
      example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
                                          len(example.ordinal_question)
                                         ) + example.ordinal_question
      example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
                                              len(example.ordinal_question_one)
                                             ) + example.ordinal_question_one
      if (True):
        #number columns and related-padding
        num_cols = len(example.columns)
        # `start` is the index of the column being processed.
        start = 0
        for column in example.number_columns:
          # NOTE(review): no mask value is appended when the check fails, so
          # processed_column_mask can end up shorter than the number of
          # columns - confirm this is intended.
          if (check_processed_cols(example.processed_number_columns[start],
                                   utility)):
            example.processed_column_mask.append(0.0)
          # Row indices ordered by decreasing processed value.
          sorted_index = sorted(
              range(len(example.processed_number_columns[start])),
              key=lambda k: example.processed_number_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_number_index.append(sorted_index)
          example.columns[start] = column + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(column))
          # `+=` extends the inner list in place.
          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_number_columns[start]))
          start += 1
          example.column_mask.append(0.0)
        # Add all-padding columns up to FLAGS.max_number_cols; their mask
        # value is a large negative number instead of 0.0.
        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
          example.sorted_number_index.append([utility.FLAGS.pad_int] *
                                             (utility.FLAGS.max_elements))
          example.columns.append([utility.FLAGS.pad_int] *
                                 (utility.FLAGS.max_elements))
          example.processed_number_columns.append([utility.FLAGS.pad_int] *
                                                  (utility.FLAGS.max_elements))
          example.number_exact_match.append([0.0] *
                                            (utility.FLAGS.max_elements))
          example.number_group_by_max.append([0.0] *
                                             (utility.FLAGS.max_elements))
          example.column_mask.append(-100000000.0)
          example.processed_column_mask.append(-100000000.0)
          example.number_column_exact_match.append(0.0)
          example.column_names.append([utility.dummy_token])
        #word column and related-padding
        start = 0
        word_num_cols = len(example.word_columns)
        for column in example.word_columns:
          # NOTE(review): as for number columns, nothing is appended to
          # processed_word_column_mask when the check fails.
          if (check_processed_cols(example.processed_word_columns[start],
                                   utility)):
            example.processed_word_column_mask.append(0.0)
          sorted_index = sorted(
              range(len(example.processed_word_columns[start])),
              key=lambda k: example.processed_word_columns[start][k],
              reverse=True)
          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements - len(sorted_index))
          example.sorted_word_index.append(sorted_index)
          # Each entry becomes FLAGS.max_entry_length word ids.
          column = convert_to_int_2d_and_pad(column, utility)
          # Missing rows are filled with entries made of dummy-token ids.
          example.word_columns[start] = column + [[
              utility.word_ids[utility.dummy_token]
          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
                                                 len(column))
          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
              utility.FLAGS.max_elements -
              len(example.processed_word_columns[start]))
          # 0 for real rows, the dummy-token id for padding rows.
          example.word_column_entry_mask.append([0] * len(column) + [
              utility.word_ids[utility.dummy_token]
          ] * (utility.FLAGS.max_elements - len(column)))
          start += 1
          example.word_column_mask.append(0.0)
        # Add all-padding word columns up to FLAGS.max_word_cols.
        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
          example.sorted_word_index.append([utility.FLAGS.pad_int] *
                                           (utility.FLAGS.max_elements))
          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
                                       utility.FLAGS.max_entry_length] *
                                      (utility.FLAGS.max_elements))
          example.word_column_entry_mask.append(
              [utility.word_ids[utility.dummy_token]] *
              (utility.FLAGS.max_elements))
          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
          example.processed_word_columns.append([utility.FLAGS.pad_int] *
                                                (utility.FLAGS.max_elements))
          example.word_column_mask.append(-100000000.0)
          example.processed_word_column_mask.append(-100000000.0)
          example.word_column_exact_match.append(0.0)
          example.word_column_names.append([utility.dummy_token] *
                                           utility.FLAGS.max_entry_length)
        # Record the table key in the module-level seen_tables dict.
        seen_tables[example.table_key] = 1
        #convert column and word column names to integers
        example.column_ids = convert_to_int_2d_and_pad(example.column_names,
                                                       utility)
        example.word_column_ids = convert_to_int_2d_and_pad(
            example.word_column_names, utility)
        # Pad the rows of the match / group-by-max features to
        # FLAGS.max_elements. The group_by_max list is indexed with the range
        # of the exact_match list, so both must have the same length.
        for i_em in range(len(example.number_exact_match)):
          example.number_exact_match[i_em] = example.number_exact_match[
              i_em] + [0.0] * (utility.FLAGS.max_elements -
                               len(example.number_exact_match[i_em]))
          example.number_group_by_max[i_em] = example.number_group_by_max[
              i_em] + [0.0] * (utility.FLAGS.max_elements -
                               len(example.number_group_by_max[i_em]))
        for i_em in range(len(example.word_exact_match)):
          example.word_exact_match[i_em] = example.word_exact_match[
              i_em] + [0.0] * (utility.FLAGS.max_elements -
                               len(example.word_exact_match[i_em]))
          example.word_group_by_max[i_em] = example.word_group_by_max[
              i_em] + [0.0] * (utility.FLAGS.max_elements -
                               len(example.word_group_by_max[i_em]))
        # Combined features list number columns first, then word columns.
        example.exact_match = example.number_exact_match + example.word_exact_match
        example.group_by_max = example.number_group_by_max + example.word_group_by_max
        example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
        #answer and related mask, padding
        if (example.is_lookup):
          example.answer = example.calc_answer
          # The print answers are float copies of the (transposed) lookup
          # matrices; the matrices themselves become padded booleans below.
          example.number_print_answer = example.number_lookup_matrix.tolist()
          example.word_print_answer = example.word_lookup_matrix.tolist()
          for i_answer in range(len(example.number_print_answer)):
            example.number_print_answer[i_answer] = example.number_print_answer[
                i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                     len(example.number_print_answer[i_answer]))
          for i_answer in range(len(example.word_print_answer)):
            example.word_print_answer[i_answer] = example.word_print_answer[
                i_answer] + [0.0] * (utility.FLAGS.max_elements -
                                     len(example.word_print_answer[i_answer]))
          example.number_lookup_matrix = convert_to_bool_and_pad(
              example.number_lookup_matrix, utility)
          example.word_lookup_matrix = convert_to_bool_and_pad(
              example.word_lookup_matrix, utility)
          # Add empty rows for the padding columns.
          for remaining in range(num_cols, utility.FLAGS.max_number_cols):
            example.number_lookup_matrix.append([False] *
                                                utility.FLAGS.max_elements)
            example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
          for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
            example.word_lookup_matrix.append([False] *
                                              utility.FLAGS.max_elements)
            example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
          example.print_answer = example.number_print_answer + example.word_print_answer
        else:
          example.answer = example.calc_answer
          # All-zero print answer; list multiplication makes every row the
          # same list object.
          example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
              utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
        #question_number masks
        # The two masks use different conventions: the first is a vector of
        # zeros (no number) or ones, the second a scalar -10000.0 (no second
        # number) or 0.0.
        if (example.question_number == -1):
          example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
        else:
          example.question_number_mask = np.ones([utility.FLAGS.max_elements])
        if (example.question_number_1 == -1):
          example.question_number_one_mask = -10000.0
        else:
          example.question_number_one_mask = np.float64(0.0)
        # Tables with more rows than FLAGS.max_elements are not returned.
        if (example.len_col > utility.FLAGS.max_elements):
          continue
        processed_data.append(example)
  return processed_data
  498. def add_special_words(utility):
  499. utility.words.append(utility.entry_match_token)
  500. utility.word_ids[utility.entry_match_token] = len(utility.word_ids)
  501. utility.reverse_word_ids[utility.word_ids[
  502. utility.entry_match_token]] = utility.entry_match_token
  503. utility.entry_match_token_id = utility.word_ids[utility.entry_match_token]
  504. print "entry match token: ", utility.word_ids[
  505. utility.entry_match_token], utility.entry_match_token_id
  506. utility.words.append(utility.column_match_token)
  507. utility.word_ids[utility.column_match_token] = len(utility.word_ids)
  508. utility.reverse_word_ids[utility.word_ids[
  509. utility.column_match_token]] = utility.column_match_token
  510. utility.column_match_token_id = utility.word_ids[utility.column_match_token]
  511. print "entry match token: ", utility.word_ids[
  512. utility.column_match_token], utility.column_match_token_id
  513. utility.words.append(utility.dummy_token)
  514. utility.word_ids[utility.dummy_token] = len(utility.word_ids)
  515. utility.reverse_word_ids[utility.word_ids[
  516. utility.dummy_token]] = utility.dummy_token
  517. utility.dummy_token_id = utility.word_ids[utility.dummy_token]
  518. utility.words.append(utility.unk_token)
  519. utility.word_ids[utility.unk_token] = len(utility.word_ids)
  520. utility.reverse_word_ids[utility.word_ids[
  521. utility.unk_token]] = utility.unk_token
  522. def perform_word_cutoff(utility):
  523. if (utility.FLAGS.word_cutoff > 0):
  524. for word in utility.word_ids.keys():
  525. if (utility.word_count.has_key(word) and utility.word_count[word] <
  526. utility.FLAGS.word_cutoff and word != utility.unk_token and
  527. word != utility.dummy_token and word != utility.entry_match_token and
  528. word != utility.column_match_token):
  529. utility.word_ids.pop(word)
  530. utility.words.remove(word)
  531. def word_dropout(question, utility):
  532. if (utility.FLAGS.word_dropout_prob > 0.0):
  533. new_question = []
  534. for i in range(len(question)):
  535. if (question[i] != utility.dummy_token_id and
  536. utility.random.random() > utility.FLAGS.word_dropout_prob):
  537. new_question.append(utility.word_ids[utility.unk_token])
  538. else:
  539. new_question.append(question[i])
  540. return new_question
  541. else:
  542. return question
  543. def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
  544. #prepare feed dict dictionary
  545. feed_dict = {}
  546. feed_examples = []
  547. for j in range(batch_size):
  548. feed_examples.append(data[curr + j])
  549. if (train):
  550. feed_dict[gr.batch_question] = [
  551. word_dropout(feed_examples[j].question, utility)
  552. for j in range(batch_size)
  553. ]
  554. else:
  555. feed_dict[gr.batch_question] = [
  556. feed_examples[j].question for j in range(batch_size)
  557. ]
  558. feed_dict[gr.batch_question_attention_mask] = [
  559. feed_examples[j].question_attention_mask for j in range(batch_size)
  560. ]
  561. feed_dict[
  562. gr.batch_answer] = [feed_examples[j].answer for j in range(batch_size)]
  563. feed_dict[gr.batch_number_column] = [
  564. feed_examples[j].columns for j in range(batch_size)
  565. ]
  566. feed_dict[gr.batch_processed_number_column] = [
  567. feed_examples[j].processed_number_columns for j in range(batch_size)
  568. ]
  569. feed_dict[gr.batch_processed_sorted_index_number_column] = [
  570. feed_examples[j].sorted_number_index for j in range(batch_size)
  571. ]
  572. feed_dict[gr.batch_processed_sorted_index_word_column] = [
  573. feed_examples[j].sorted_word_index for j in range(batch_size)
  574. ]
  575. feed_dict[gr.batch_question_number] = np.array(
  576. [feed_examples[j].question_number for j in range(batch_size)]).reshape(
  577. (batch_size, 1))
  578. feed_dict[gr.batch_question_number_one] = np.array(
  579. [feed_examples[j].question_number_1 for j in range(batch_size)]).reshape(
  580. (batch_size, 1))
  581. feed_dict[gr.batch_question_number_mask] = [
  582. feed_examples[j].question_number_mask for j in range(batch_size)
  583. ]
  584. feed_dict[gr.batch_question_number_one_mask] = np.array(
  585. [feed_examples[j].question_number_one_mask for j in range(batch_size)
  586. ]).reshape((batch_size, 1))
  587. feed_dict[gr.batch_print_answer] = [
  588. feed_examples[j].print_answer for j in range(batch_size)
  589. ]
  590. feed_dict[gr.batch_exact_match] = [
  591. feed_examples[j].exact_match for j in range(batch_size)
  592. ]
  593. feed_dict[gr.batch_group_by_max] = [
  594. feed_examples[j].group_by_max for j in range(batch_size)
  595. ]
  596. feed_dict[gr.batch_column_exact_match] = [
  597. feed_examples[j].exact_column_match for j in range(batch_size)
  598. ]
  599. feed_dict[gr.batch_ordinal_question] = [
  600. feed_examples[j].ordinal_question for j in range(batch_size)
  601. ]
  602. feed_dict[gr.batch_ordinal_question_one] = [
  603. feed_examples[j].ordinal_question_one for j in range(batch_size)
  604. ]
  605. feed_dict[gr.batch_number_column_mask] = [
  606. feed_examples[j].column_mask for j in range(batch_size)
  607. ]
  608. feed_dict[gr.batch_number_column_names] = [
  609. feed_examples[j].column_ids for j in range(batch_size)
  610. ]
  611. feed_dict[gr.batch_processed_word_column] = [
  612. feed_examples[j].processed_word_columns for j in range(batch_size)
  613. ]
  614. feed_dict[gr.batch_word_column_mask] = [
  615. feed_examples[j].word_column_mask for j in range(batch_size)
  616. ]
  617. feed_dict[gr.batch_word_column_names] = [
  618. feed_examples[j].word_column_ids for j in range(batch_size)
  619. ]
  620. feed_dict[gr.batch_word_column_entry_mask] = [
  621. feed_examples[j].word_column_entry_mask for j in range(batch_size)
  622. ]
  623. return feed_dict