# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

  15. """Loads the WikiQuestions dataset.
  16. An example consists of question, table. Additionally, we store the processed
  17. columns which store the entries after performing number, date and other
  18. preprocessing as done in the baseline.
  19. columns, column names and processed columns are split into word and number
  20. columns.
  21. lookup answer (or matrix) is also split into number and word lookup matrix
  22. Author: aneelakantan (Arvind Neelakantan)
  23. """
import math
import os
import re
import numpy as np
import unicodedata as ud
import tensorflow as tf

# Number that is assigned to a corrupted table entry in a number column.
bad_number = -200000.0

def is_nan_or_inf(number):
  return math.isnan(number) or math.isinf(number)

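# Strips combining accent marks (e.g. "é" -> "e") using NFKD decomposition.
# Python 2: takes and returns UTF-8 encoded byte strings.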
def strip_accents(s):
  u = unicode(s, "utf-8")
  u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
  return u_new.encode("utf-8")

def correct_unicode(string):
  string = strip_accents(string)
  string = re.sub("\xc2\xa0", " ", string).strip()  # UTF-8 non-breaking space
  string = re.sub("\xe2\x80\x93", "-", string).strip()  # UTF-8 en dash
  #string = re.sub(ur'[\u0300-\u036F]', "", string)
  string = re.sub("‚", ",", string)
  string = re.sub("…", "...", string)
  #string = re.sub("[·・]", ".", string)
  string = re.sub("ˆ", "^", string)
  string = re.sub("˜", "~", string)
  string = re.sub("‹", "<", string)
  string = re.sub("›", ">", string)
  #string = re.sub("[‘’´`]", "'", string)
  #string = re.sub("[“”«»]", "\"", string)
  #string = re.sub("[•†‡]", "", string)
  #string = re.sub("[‐‑–—]", "-", string)
  string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
  string = re.sub("\\s+", " ", string).strip()
  return string

def simple_normalize(string):
  string = correct_unicode(string)
  # Citations
  string = re.sub("\[(nb ?)?\d+\]", "", string)
  string = re.sub("\*+$", "", string)
  # Year in parenthesis
  string = re.sub("\(\d* ?-? ?\d*\)", "", string)
  string = re.sub("^\"(.*)\"$", "", string)
  return string

def full_normalize(string):
  #print "an: ", string
  string = simple_normalize(string)
  # Remove trailing info in brackets
  string = re.sub("\[[^\]]*\]", "", string)
  # Remove most unicode characters in other languages
  string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
  # Remove trailing info in parenthesis
  string = re.sub("\([^)]*\)$", "", string.strip())
  string = final_normalize(string)
  # Get rid of question marks
  string = re.sub("\?", "", string).strip()
  # Get rid of trailing colons (usually occur in column titles)
  string = re.sub("\:$", " ", string).strip()
  # Get rid of slashes
  string = re.sub(r"/", " ", string).strip()
  string = re.sub(r"\\", " ", string).strip()
  # Replace colon, slash, and dash with space
  # Note: need better replacement for this when parsing time
  string = re.sub(r"\:", " ", string).strip()
  string = re.sub("/", " ", string).strip()
  string = re.sub("-", " ", string).strip()
  # Convert empty strings to UNK
  # Important to do this last or near last
  if not string:
    string = "UNK"
  return string

def final_normalize(string):
  # Remove leading and trailing whitespace
  string = re.sub("\\s+", " ", string).strip()
  # Convert entirely to lowercase
  string = string.lower()
  # Get rid of strangely escaped newline characters
  string = re.sub("\\\\n", " ", string).strip()
  # Get rid of quotation marks
  string = re.sub(r"\"", "", string).strip()
  string = re.sub(r"\'", "", string).strip()
  string = re.sub(r"`", "", string).strip()
  # Get rid of *
  string = re.sub("\*", "", string).strip()
  return string

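# True for any string float() accepts, excluding NaN and infinity.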
def is_number(x):
  try:
    f = float(x)
    return not is_nan_or_inf(f)
  except ValueError:
    return False
  except TypeError:
    return False

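# One question/answer pair plus the flags set later by answer classification.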
class WikiExample(object):

  def __init__(self, id, question, answer, table_key):
    self.question_id = id
    self.question = question
    self.answer = answer
    self.table_key = table_key
    self.lookup_matrix = []
    self.is_bad_example = False
    self.is_word_lookup = False
    self.is_ambiguous_word_lookup = False
    self.is_number_lookup = False
    self.is_number_calc = False
    self.is_unknown_answer = False

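# A table's columns, column names and per-entry processed (numeric) values,
# already split into word columns and number columns.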
class TableInfo(object):

  def __init__(self, word_columns, word_column_names, word_column_indices,
               number_columns, number_column_names, number_column_indices,
               processed_word_columns, processed_number_columns, orig_columns):
    self.word_columns = word_columns
    self.word_column_names = word_column_names
    self.word_column_indices = word_column_indices
    self.number_columns = number_columns
    self.number_column_names = number_column_names
    self.number_column_indices = number_column_indices
    self.processed_word_columns = processed_word_columns
    self.processed_number_columns = processed_number_columns
    self.orig_columns = orig_columns

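# Reads the question ids of one split (train/dev/test) from the data/ folder,
# matching the "(id ...)" field of each line in the .examples file.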
class WikiQuestionLoader(object):

  def __init__(self, data_name, root_folder):
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "data")
    self.examples = []
    self.data_name = data_name

  def num_questions(self):
    return len(self.examples)

  def load_qa(self):
    data_source = os.path.join(self.data_folder, self.data_name)
    f = tf.gfile.GFile(data_source, "r")
    id_regex = re.compile("\(id ([^\)]*)\)")
    for line in f:
      id_match = id_regex.search(line)
      id = id_match.group(1)
      self.examples.append(id)

  def load(self):
    self.load_qa()

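# Recognizes normalized date strings: exactly 10 characters, dashes at
# positions 4 and 7, all other characters digits or "x"/"X" placeholders
# (e.g. "2011-07-04", "2011-xx-xx").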
def is_date(word):
  if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
    return False
  if (len(word) != 10):
    return False
  if (word[4] != "-"):
    return False
  if (word[7] != "-"):
    return False
  for i in range(len(word)):
    if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or
             re.search("[0-9]", word[i]))):
      return False
  return True

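# Builds WikiExample objects for all three splits: parses the annotated
# questions and tables and classifies every answer as lookup / number / bad.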
class WikiQuestionGenerator(object):

  def __init__(self, train_name, dev_name, test_name, root_folder):
    self.train_name = train_name
    self.dev_name = dev_name
    self.test_name = test_name
    self.train_loader = WikiQuestionLoader(train_name, root_folder)
    self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
    self.test_loader = WikiQuestionLoader(test_name, root_folder)
    self.bad_examples = 0
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "annotated/data")
    self.annotated_examples = {}
    self.annotated_tables = {}
    self.annotated_word_reject = {}
    self.annotated_word_reject["-lrb-"] = 1
    self.annotated_word_reject["-rrb-"] = 1
    self.annotated_word_reject["UNK"] = 1

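  # True when word consists only of digits, "." and "E", i.e. a normalized
  # amount whose currency symbol has already been stripped by the caller.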
  def is_money(self, word):
    if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
      return False
    for i in range(len(word)):
      if (not (word[i] == "E" or word[i] == "." or
               re.search("[0-9]", word[i]))):
        return False
    return True

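  # The annotator assigns the same normalized value to every token of a
  # multi-token entity (e.g. "July 4 , 2011" -> four DATE tokens with value
  # "2011-07-04"). Mask the duplicates so pre_process_sentence keeps only one
  # copy: "A" for alphabetic values, "," otherwise.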
  def remove_consecutive(self, ner_tags, ner_values):
    for i in range(len(ner_tags)):
      if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
          i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
          ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          ner_values[i] = "A"
        else:
          ner_values[i] = ","
    return ner_tags, ner_values

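  # Converts a "|"-separated token string into a list of words and floats:
  # number, money, percent and date tokens are replaced by their normalized
  # NER values; everything else is normalized text (or dropped if rejected).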
  def pre_process_sentence(self, tokens, ner_tags, ner_values):
    sentence = []
    tokens = tokens.split("|")
    ner_tags = ner_tags.split("|")
    ner_values = ner_values.split("|")
    ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
    #print "old: ", tokens
    for i in range(len(tokens)):
      word = tokens[i]
      if (ner_values[i] != "" and
          (ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          word = tokens[i]
        if (is_number(ner_values[i])):
          word = float(ner_values[i])
        elif (is_number(word)):
          word = float(word)
        if (tokens[i] == "score"):
          word = "score"
      if (is_number(word)):
        word = float(word)
      if (not (self.annotated_word_reject.has_key(word))):
        if (is_number(word) or is_date(word) or self.is_money(word)):
          sentence.append(word)
        else:
          word = full_normalize(word)
          if (not (self.annotated_word_reject.has_key(word)) and
              bool(re.search("[a-z0-9]", word, re.IGNORECASE))):
            sentence.append(word.replace(",", ""))
    if (len(sentence) == 0):
      sentence.append("UNK")
    return sentence

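  # Parses one .annotated questions file (TSV with a header row); each row
  # holds the question id, utterance, table path, target value and the
  # token/POS/NER annotations.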
  def load_annotated_data(self, in_file):
    self.annotated_examples = {}
    self.annotated_tables = {}
    f = tf.gfile.GFile(in_file, "r")
    counter = 0
    for line in f:
      if (counter > 0):
        line = line.strip()
        (question_id, utterance, context, target_value, tokens, lemma_tokens,
         pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
        question = self.pre_process_sentence(tokens, ner_tags, ner_values)
        target_canon = target_canon.split("|")
        self.annotated_examples[question_id] = WikiExample(
            question_id, question, target_canon, context)
        self.annotated_tables[context] = []
      counter += 1
    print "Annotated examples loaded ", len(self.annotated_examples)
    f.close()

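  # A column counts as numeric when every entry is a single numeric token.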
  def is_number_column(self, a):
    for w in a:
      if (len(w) != 1):
        return False
      if (not (is_number(w[0]))):
        return False
    return True

  def convert_table(self, table):
    answer = []
    for i in range(len(table)):
      temp = []
      for j in range(len(table[i])):
        temp.append(" ".join([str(w) for w in table[i][j]]))
      answer.append(temp)
    return answer

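  # Reads each table's .annotated file in two passes: the first pass finds the
  # table dimensions (row/col of the last line), the second fills the original
  # and processed (numeric) columns and then splits them into word and number
  # columns.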
  def load_annotated_tables(self):
    for table in self.annotated_tables.keys():
      annotated_table = table.replace("csv", "annotated")
      orig_columns = []
      processed_columns = []
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags,
           ner_tags, ner_values, number, date, num2,
           read_list) = line.split("\t")
        counter += 1
      f.close()
      max_row = int(row)
      max_col = int(col)
      for i in range(max_col + 1):
        orig_columns.append([])
        processed_columns.append([])
        for j in range(max_row + 1):
          orig_columns[i].append(bad_number)
          processed_columns[i].append(bad_number)
      #print orig_columns
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      column_names = []
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags,
           ner_tags, ner_values, number, date, num2,
           read_list) = line.split("\t")
          entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
          if (row == "-1"):
            column_names.append(entry)
          else:
            orig_columns[int(col)][int(row)] = entry
            if (len(entry) == 1 and is_number(entry[0])):
              processed_columns[int(col)][int(row)] = float(entry[0])
            else:
              for single_entry in entry:
                if (is_number(single_entry)):
                  processed_columns[int(col)][int(row)] = float(single_entry)
                  break
              nt = ner_tags.split("|")
              nv = ner_values.split("|")
              for i_entry in range(len(tokens.split("|"))):
                if (nt[i_entry] == "DATE" and
                    is_number(nv[i_entry].replace("-", "").replace("X", ""))):
                  processed_columns[int(col)][int(row)] = float(
                      nv[i_entry].replace("-", "").replace("X", ""))
                  #processed_columns[int(col)][int(row)] = float(nv[i_entry])
            if (len(entry) == 1 and (is_number(entry[0]) or is_date(entry[0])
                                     or self.is_money(entry[0]))):
              if (len(entry) == 1 and not (is_number(entry[0])) and
                  is_date(entry[0])):
                entry[0] = entry[0].replace("X", "x")
        counter += 1
      word_columns = []
      processed_word_columns = []
      word_column_names = []
      word_column_indices = []
      number_columns = []
      processed_number_columns = []
      number_column_names = []
      number_column_indices = []
      for i in range(max_col + 1):
        if (self.is_number_column(orig_columns[i])):
          number_column_indices.append(i)
          number_column_names.append(column_names[i])
          temp = []
          for w in orig_columns[i]:
            if (is_number(w[0])):
              temp.append(w[0])
          number_columns.append(temp)
          processed_number_columns.append(processed_columns[i])
        else:
          word_column_indices.append(i)
          word_column_names.append(column_names[i])
          word_columns.append(orig_columns[i])
          processed_word_columns.append(processed_columns[i])
      table_info = TableInfo(
          word_columns, word_column_names, word_column_indices, number_columns,
          number_column_names, number_column_indices, processed_word_columns,
          processed_number_columns, orig_columns)
      self.annotated_tables[table] = table_info
      f.close()

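  # Labels every loaded example using the cell-match file
  # arvind-with-norms-2.tsv: answers found verbatim in the table become word
  # or number lookups (with a lookup matrix marking the matched cells), single
  # numeric answers not in the table become number-calculation questions, and
  # everything else is flagged as a bad example.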
  def answer_classification(self):
    lookup_questions = 0
    number_lookup_questions = 0
    word_lookup_questions = 0
    ambiguous_lookup_questions = 0
    number_questions = 0
    bad_questions = 0
    ice_bad_questions = 0
    tot = 0
    got = 0
    ice = {}
    with tf.gfile.GFile(
        self.root_folder + "/arvind-with-norms-2.tsv", mode="r") as f:
      lines = f.readlines()
      for line in lines:
        line = line.strip()
        if (not (self.annotated_examples.has_key(line.split("\t")[0]))):
          continue
        if (len(line.split("\t")) == 4):
          line = line + "\t" * (5 - len(line.split("\t")))
        if (not (is_number(line.split("\t")[2]))):
          ice_bad_questions += 1
        (example_id, ans_index, ans_raw, process_answer,
         matched_cells) = line.split("\t")
        if (ice.has_key(example_id)):
          ice[example_id].append(line.split("\t"))
        else:
          ice[example_id] = [line.split("\t")]
      for q_id in self.annotated_examples.keys():
        tot += 1
        example = self.annotated_examples[q_id]
        table_info = self.annotated_tables[example.table_key]
        # Figure out if the answer is numerical or lookup
        n_cols = len(table_info.orig_columns)
        n_rows = len(table_info.orig_columns[0])
        example.lookup_matrix = np.zeros((n_rows, n_cols))
        exact_matches = {}
        for (example_id, ans_index, ans_raw, process_answer,
             matched_cells) in ice[q_id]:
          for match_cell in matched_cells.split("|"):
            if (len(match_cell.split(",")) == 2):
              (row, col) = match_cell.split(",")
              row = int(row)
              col = int(col)
              if (row >= 0):
                exact_matches[ans_index] = 1
        answer_is_in_table = len(exact_matches) == len(example.answer)
        if (answer_is_in_table):
          for (example_id, ans_index, ans_raw, process_answer,
               matched_cells) in ice[q_id]:
            for match_cell in matched_cells.split("|"):
              if (len(match_cell.split(",")) == 2):
                (row, col) = match_cell.split(",")
                row = int(row)
                col = int(col)
                example.lookup_matrix[row, col] = float(ans_index) + 1.0
        example.lookup_number_answer = 0.0
        if (answer_is_in_table):
          lookup_questions += 1
          if len(example.answer) == 1 and is_number(example.answer[0]):
            example.number_answer = float(example.answer[0])
            number_lookup_questions += 1
            example.is_number_lookup = True
          else:
            #print "word lookup"
            example.calc_answer = example.number_answer = 0.0
            word_lookup_questions += 1
            example.is_word_lookup = True
        else:
          if (len(example.answer) == 1 and is_number(example.answer[0])):
            example.number_answer = example.answer[0]
            example.is_number_calc = True
          else:
            bad_questions += 1
            example.is_bad_example = True
            example.is_unknown_answer = True
        example.is_lookup = example.is_word_lookup or example.is_number_lookup
        if not example.is_word_lookup and not example.is_bad_example:
          number_questions += 1
          example.calc_answer = example.answer[0]
          example.lookup_number_answer = example.calc_answer
        # Split up the lookup matrix into word part and number part
        number_column_indices = table_info.number_column_indices
        word_column_indices = table_info.word_column_indices
        example.word_columns = table_info.word_columns
        example.number_columns = table_info.number_columns
        example.word_column_names = table_info.word_column_names
        example.processed_number_columns = table_info.processed_number_columns
        example.processed_word_columns = table_info.processed_word_columns
        example.number_column_names = table_info.number_column_names
        example.number_lookup_matrix = example.lookup_matrix[
            :, number_column_indices]
        example.word_lookup_matrix = example.lookup_matrix[
            :, word_column_indices]

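  # Loads all three splits and returns (train_data, dev_data, test_data) as
  # lists of classified WikiExample objects. Train and dev questions are both
  # looked up in training.annotated; test uses the pristine unseen tables.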
  def load(self):
    train_data = []
    dev_data = []
    test_data = []
    self.load_annotated_data(
        os.path.join(self.data_folder, "training.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.train_loader.load()
    self.dev_loader.load()
    for i in range(self.train_loader.num_questions()):
      example = self.train_loader.examples[i]
      example = self.annotated_examples[example]
      train_data.append(example)
    for i in range(self.dev_loader.num_questions()):
      example = self.dev_loader.examples[i]
      dev_data.append(self.annotated_examples[example])
    self.load_annotated_data(
        os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.test_loader.load()
    for i in range(self.test_loader.num_questions()):
      example = self.test_loader.examples[i]
      test_data.append(self.annotated_examples[example])
    return train_data, dev_data, test_data
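
# Minimal usage sketch (not part of the original pipeline): assumes `root`
# points at a WikiTableQuestions checkout laid out as this loader expects --
# question splits under data/, annotated files under annotated/data/, and
# arvind-with-norms-2.tsv at the root. The split file names below are
# illustrative assumptions.
if __name__ == "__main__":
  import sys
  root = sys.argv[1] if len(sys.argv) > 1 else "wiki_data/"
  generator = WikiQuestionGenerator("random-split-1-train.examples",
                                    "random-split-1-dev.examples",
                                    "pristine-unseen-tables.examples", root)
  train_data, dev_data, test_data = generator.load()
  print "train: ", len(train_data), " dev: ", len(dev_data), \
      " test: ", len(test_data)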