# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Loads the WikiQuestions dataset.
- An example consists of question, table. Additionally, we store the processed
- columns which store the entries after performing number, date and other
- preprocessing as done in the baseline.
- columns, column names and processed columns are split into word and number
- columns.
- lookup answer (or matrix) is also split into number and word lookup matrix
- Author: aneelakantan (Arvind Neelakantan)
- """
import math
import os
import re
import unicodedata as ud

import numpy as np
import tensorflow as tf

bad_number = -200000.0  #number that is added to a corrupted table entry in a number column


def is_nan_or_inf(number):
  return math.isnan(number) or math.isinf(number)


def strip_accents(s):
  u = unicode(s, "utf-8")
  u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
  return u_new.encode("utf-8")
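# Illustrative example (not from the original source): in Python 2,
# strip_accents("Andr\xc3\xa9") returns "Andre" -- NFKD decomposition
# followed by dropping combining marks (Unicode category 'Mn').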
def correct_unicode(string):
  string = strip_accents(string)
  string = re.sub("\xc2\xa0", " ", string).strip()
  string = re.sub("\xe2\x80\x93", "-", string).strip()
  #string = re.sub(ur'[\u0300-\u036F]', "", string)
  string = re.sub("‚", ",", string)
  string = re.sub("…", "...", string)
  #string = re.sub("[·・]", ".", string)
  string = re.sub("ˆ", "^", string)
  string = re.sub("˜", "~", string)
  string = re.sub("‹", "<", string)
  string = re.sub("›", ">", string)
  #string = re.sub("[‘’´`]", "'", string)
  #string = re.sub("[“”«»]", "\"", string)
  #string = re.sub("[•†‡]", "", string)
  #string = re.sub("[‐‑–—]", "-", string)
  string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
  string = re.sub("\\s+", " ", string).strip()
  return string


def simple_normalize(string):
  string = correct_unicode(string)
  # Citations
  string = re.sub("\[(nb ?)?\d+\]", "", string)
  string = re.sub("\*+$", "", string)
  # Year in parenthesis
  string = re.sub("\(\d* ?-? ?\d*\)", "", string)
  string = re.sub("^\"(.*)\"$", "", string)
  return string


def full_normalize(string):
  #print "an: ", string
  string = simple_normalize(string)
  # Remove trailing info in brackets
  string = re.sub("\[[^\]]*\]", "", string)
  # Remove most unicode characters in other languages
  string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
  # Remove trailing info in parenthesis
  string = re.sub("\([^)]*\)$", "", string.strip())
  string = final_normalize(string)
  # Get rid of question marks
  string = re.sub("\?", "", string).strip()
  # Get rid of trailing colons (usually occur in column titles)
  string = re.sub("\:$", " ", string).strip()
  # Get rid of slashes
  string = re.sub(r"/", " ", string).strip()
  string = re.sub(r"\\", " ", string).strip()
  # Replace colon, slash, and dash with space
  # Note: need better replacement for this when parsing time
  string = re.sub(r"\:", " ", string).strip()
  string = re.sub("/", " ", string).strip()
  string = re.sub("-", " ", string).strip()
  # Convert empty strings to UNK
  # Important to do this last or near last
  if not string:
    string = "UNK"
  return string
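# Illustrative example (not from the original source):
#   full_normalize('Greece [a] (1896)') -> 'greece'
# Bracketed text and the trailing year in parentheses are stripped, then
# final_normalize() collapses whitespace and lowercases the result.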
def final_normalize(string):
  # Remove leading and trailing whitespace
  string = re.sub("\\s+", " ", string).strip()
  # Convert entirely to lowercase
  string = string.lower()
  # Get rid of strangely escaped newline characters
  string = re.sub("\\\\n", " ", string).strip()
  # Get rid of quotation marks
  string = re.sub(r"\"", "", string).strip()
  string = re.sub(r"\'", "", string).strip()
  string = re.sub(r"`", "", string).strip()
  # Get rid of *
  string = re.sub("\*", "", string).strip()
  return string


def is_number(x):
  try:
    f = float(x)
    return not is_nan_or_inf(f)
  except ValueError:
    return False
  except TypeError:
    return False
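# Illustrative examples (not from the original source):
#   is_number("3.5") -> True
#   is_number("1e3") -> True
#   is_number("nan") -> False  (NaN and inf parse as floats but are rejected)
#   is_number("abc") -> False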
class WikiExample(object):

  def __init__(self, id, question, answer, table_key):
    self.question_id = id
    self.question = question
    self.answer = answer
    self.table_key = table_key
    self.lookup_matrix = []
    self.is_bad_example = False
    self.is_word_lookup = False
    self.is_ambiguous_word_lookup = False
    self.is_number_lookup = False
    self.is_number_calc = False
    self.is_unknown_answer = False


class TableInfo(object):

  def __init__(self, word_columns, word_column_names, word_column_indices,
               number_columns, number_column_names, number_column_indices,
               processed_word_columns, processed_number_columns, orig_columns):
    self.word_columns = word_columns
    self.word_column_names = word_column_names
    self.word_column_indices = word_column_indices
    self.number_columns = number_columns
    self.number_column_names = number_column_names
    self.number_column_indices = number_column_indices
    self.processed_word_columns = processed_word_columns
    self.processed_number_columns = processed_number_columns
    self.orig_columns = orig_columns


class WikiQuestionLoader(object):

  def __init__(self, data_name, root_folder):
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "data")
    self.examples = []
    self.data_name = data_name

  def num_questions(self):
    return len(self.examples)

  def load_qa(self):
    data_source = os.path.join(self.data_folder, self.data_name)
    f = tf.gfile.GFile(data_source, "r")
    id_regex = re.compile("\(id ([^\)]*)\)")
    for line in f:
      id_match = id_regex.search(line)
      id = id_match.group(1)
      self.examples.append(id)

  def load(self):
    self.load_qa()


def is_date(word):
  if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
    return False
  if (len(word) != 10):
    return False
  if (word[4] != "-"):
    return False
  if (word[7] != "-"):
    return False
  for i in range(len(word)):
    if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or re.search(
        "[0-9]", word[i]))):
      return False
  return True
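# Illustrative examples (not from the original source). Dates are expected
# in the normalized form YYYY-MM-DD, where unknown fields are X:
#   is_date("2010-08-21") -> True
#   is_date("2010-XX-XX") -> True
#   is_date("21/08/2010") -> False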
class WikiQuestionGenerator(object):

  def __init__(self, train_name, dev_name, test_name, root_folder):
    self.train_name = train_name
    self.dev_name = dev_name
    self.test_name = test_name
    self.train_loader = WikiQuestionLoader(train_name, root_folder)
    self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
    self.test_loader = WikiQuestionLoader(test_name, root_folder)
    self.bad_examples = 0
    self.root_folder = root_folder
    self.data_folder = os.path.join(self.root_folder, "annotated/data")
    self.annotated_examples = {}
    self.annotated_tables = {}
    self.annotated_word_reject = {}
    self.annotated_word_reject["-lrb-"] = 1
    self.annotated_word_reject["-rrb-"] = 1
    self.annotated_word_reject["UNK"] = 1

  def is_money(self, word):
    if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
      return False
    for i in range(len(word)):
      if (not (word[i] == "E" or word[i] == "." or re.search("[0-9]",
                                                             word[i]))):
        return False
    return True
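  # Illustrative examples (not from the original source). Money values are
  # checked after their currency symbols have been stripped, so only digits,
  # "E" (scientific notation) and "." may remain:
  #   is_money("14.0")  -> True
  #   is_money("2.0E7") -> True
  #   is_money("$14")   -> False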
  def remove_consecutive(self, ner_tags, ner_values):
    for i in range(len(ner_tags)):
      if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
          i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
          ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          ner_values[i] = "A"
        else:
          ner_values[i] = ","
    return ner_tags, ner_values
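  # Explanatory note (not an original comment): remove_consecutive() collapses
  # runs of identical normalized NER values (e.g. both tokens of "August 2010"
  # carrying "2010-08-XX") so only the last token of a run keeps its value;
  # earlier tokens are marked "A" (keep the raw token) or "," (drop it), which
  # pre_process_sentence() below interprets accordingly.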
  def pre_process_sentence(self, tokens, ner_tags, ner_values):
    sentence = []
    tokens = tokens.split("|")
    ner_tags = ner_tags.split("|")
    ner_values = ner_values.split("|")
    ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
    #print "old: ", tokens
    for i in range(len(tokens)):
      word = tokens[i]
      if (ner_values[i] != "" and
          (ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
        word = ner_values[i]
        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
                "€", "")
        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
            self.is_money(word))):
          word = tokens[i]
        if (is_number(ner_values[i])):
          word = float(ner_values[i])
        elif (is_number(word)):
          word = float(word)
        if (tokens[i] == "score"):
          word = "score"
      if (is_number(word)):
        word = float(word)
      if (not (self.annotated_word_reject.has_key(word))):
        if (is_number(word) or is_date(word) or self.is_money(word)):
          sentence.append(word)
        else:
          word = full_normalize(word)
          if (not (self.annotated_word_reject.has_key(word)) and
              bool(re.search("[a-z0-9]", word, re.IGNORECASE))):
            sentence.append(word.replace(",", ""))
    if (len(sentence) == 0):
      sentence.append("UNK")
    return sentence
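  # Illustrative example (not from the original source): for the annotated
  # fields tokens="how|many|points|in|1986", ner_tags="O|O|O|O|DATE" and
  # ner_values="||||1986", pre_process_sentence() yields
  # ["how", "many", "points", "in", 1986.0] -- the DATE token is replaced by
  # its normalized value and parsed as a float, the rest are normalized words.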
  def load_annotated_data(self, in_file):
    self.annotated_examples = {}
    self.annotated_tables = {}
    f = tf.gfile.GFile(in_file, "r")
    counter = 0
    for line in f:
      if (counter > 0):
        line = line.strip()
        (question_id, utterance, context, target_value, tokens, lemma_tokens,
         pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
        question = self.pre_process_sentence(tokens, ner_tags, ner_values)
        target_canon = target_canon.split("|")
        self.annotated_examples[question_id] = WikiExample(
            question_id, question, target_canon, context)
        self.annotated_tables[context] = []
      counter += 1
    print "Annotated examples loaded ", len(self.annotated_examples)
    f.close()

  def is_number_column(self, a):
    for w in a:
      if (len(w) != 1):
        return False
      if (not (is_number(w[0]))):
        return False
    return True
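  # Illustrative examples (not from the original source): a column counts as
  # a number column only if every cell is a single numeric entry.
  #   is_number_column([[2.0], [4.0]])      -> True
  #   is_number_column([[2.0], ["april"]])  -> False
  #   is_number_column([[2.0], [4.0, 5.0]]) -> False  (multi-token cell)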
  def convert_table(self, table):
    answer = []
    for i in range(len(table)):
      temp = []
      for j in range(len(table[i])):
        temp.append(" ".join([str(w) for w in table[i][j]]))
      answer.append(temp)
    return answer
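  # Explanatory note (not an original comment): convert_table() flattens each
  # multi-token cell into one space-joined string, e.g.
  # [[["san", "francisco"]]] -> [["san francisco"]].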
  def load_annotated_tables(self):
    for table in self.annotated_tables.keys():
      annotated_table = table.replace("csv", "annotated")
      orig_columns = []
      processed_columns = []
      # First pass: read to the last line to find the table dimensions.
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
           ner_values, number, date, num2, read_list) = line.split("\t")
        counter += 1
      f.close()
      max_row = int(row)
      max_col = int(col)
      for i in range(max_col + 1):
        orig_columns.append([])
        processed_columns.append([])
        for j in range(max_row + 1):
          orig_columns[i].append(bad_number)
          processed_columns[i].append(bad_number)
      #print orig_columns
      # Second pass: fill in the column names and the column entries.
      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
      counter = 0
      column_names = []
      for line in f:
        if (counter > 0):
          line = line.strip()
          line = line + "\t" * (13 - len(line.split("\t")))
          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
           ner_values, number, date, num2, read_list) = line.split("\t")
          entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
          if (row == "-1"):
            column_names.append(entry)
          else:
            orig_columns[int(col)][int(row)] = entry
            if (len(entry) == 1 and is_number(entry[0])):
              processed_columns[int(col)][int(row)] = float(entry[0])
            else:
              for single_entry in entry:
                if (is_number(single_entry)):
                  processed_columns[int(col)][int(row)] = float(single_entry)
                  break
              nt = ner_tags.split("|")
              nv = ner_values.split("|")
              for i_entry in range(len(tokens.split("|"))):
                if (nt[i_entry] == "DATE" and
                    is_number(nv[i_entry].replace("-", "").replace("X", ""))):
                  processed_columns[int(col)][int(row)] = float(nv[
                      i_entry].replace("-", "").replace("X", ""))
                  #processed_columns[int(col)][int(row)] = float(nv[i_entry])
            if (len(entry) == 1 and (is_number(entry[0]) or is_date(entry[0]) or
                                     self.is_money(entry[0]))):
              if (len(entry) == 1 and not (is_number(entry[0])) and
                  is_date(entry[0])):
                entry[0] = entry[0].replace("X", "x")
        counter += 1
      # Split the columns into number columns and word columns.
      word_columns = []
      processed_word_columns = []
      word_column_names = []
      word_column_indices = []
      number_columns = []
      processed_number_columns = []
      number_column_names = []
      number_column_indices = []
      for i in range(max_col + 1):
        if (self.is_number_column(orig_columns[i])):
          number_column_indices.append(i)
          number_column_names.append(column_names[i])
          temp = []
          for w in orig_columns[i]:
            if (is_number(w[0])):
              temp.append(w[0])
          number_columns.append(temp)
          processed_number_columns.append(processed_columns[i])
        else:
          word_column_indices.append(i)
          word_column_names.append(column_names[i])
          word_columns.append(orig_columns[i])
          processed_word_columns.append(processed_columns[i])
      table_info = TableInfo(
          word_columns, word_column_names, word_column_indices, number_columns,
          number_column_names, number_column_indices, processed_word_columns,
          processed_number_columns, orig_columns)
      self.annotated_tables[table] = table_info
      f.close()
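  # Note (not an original comment): each ".annotated" table file lists one
  # cell per tab-separated line with its row and col indices; row -1 holds
  # the column headers, which is why those lines feed column_names above.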
  def answer_classification(self):
    lookup_questions = 0
    number_lookup_questions = 0
    word_lookup_questions = 0
    ambiguous_lookup_questions = 0
    number_questions = 0
    bad_questions = 0
    ice_bad_questions = 0
    tot = 0
    got = 0
    ice = {}
    with tf.gfile.GFile(
        self.root_folder + "/arvind-with-norms-2.tsv", mode="r") as f:
      lines = f.readlines()
      for line in lines:
        line = line.strip()
        if (not (self.annotated_examples.has_key(line.split("\t")[0]))):
          continue
        if (len(line.split("\t")) == 4):
          line = line + "\t" * (5 - len(line.split("\t")))
        if (not (is_number(line.split("\t")[2]))):
          ice_bad_questions += 1
        (example_id, ans_index, ans_raw, process_answer,
         matched_cells) = line.split("\t")
        if (ice.has_key(example_id)):
          ice[example_id].append(line.split("\t"))
        else:
          ice[example_id] = [line.split("\t")]
    for q_id in self.annotated_examples.keys():
      tot += 1
      example = self.annotated_examples[q_id]
      table_info = self.annotated_tables[example.table_key]
      # Figure out if the answer is numerical or lookup
      n_cols = len(table_info.orig_columns)
      n_rows = len(table_info.orig_columns[0])
      example.lookup_matrix = np.zeros((n_rows, n_cols))
      exact_matches = {}
      for (example_id, ans_index, ans_raw, process_answer,
           matched_cells) in ice[q_id]:
        for match_cell in matched_cells.split("|"):
          if (len(match_cell.split(",")) == 2):
            (row, col) = match_cell.split(",")
            row = int(row)
            col = int(col)
            if (row >= 0):
              exact_matches[ans_index] = 1
      answer_is_in_table = len(exact_matches) == len(example.answer)
      if (answer_is_in_table):
        for (example_id, ans_index, ans_raw, process_answer,
             matched_cells) in ice[q_id]:
          for match_cell in matched_cells.split("|"):
            if (len(match_cell.split(",")) == 2):
              (row, col) = match_cell.split(",")
              row = int(row)
              col = int(col)
              example.lookup_matrix[row, col] = float(ans_index) + 1.0
      example.lookup_number_answer = 0.0
      if (answer_is_in_table):
        lookup_questions += 1
        if len(example.answer) == 1 and is_number(example.answer[0]):
          example.number_answer = float(example.answer[0])
          number_lookup_questions += 1
          example.is_number_lookup = True
        else:
          #print "word lookup"
          example.calc_answer = example.number_answer = 0.0
          word_lookup_questions += 1
          example.is_word_lookup = True
      else:
        if (len(example.answer) == 1 and is_number(example.answer[0])):
          example.number_answer = example.answer[0]
          example.is_number_calc = True
        else:
          bad_questions += 1
          example.is_bad_example = True
          example.is_unknown_answer = True
      example.is_lookup = example.is_word_lookup or example.is_number_lookup
      if not example.is_word_lookup and not example.is_bad_example:
        number_questions += 1
        example.calc_answer = example.answer[0]
        example.lookup_number_answer = example.calc_answer
      # Split up the lookup matrix into word part and number part
      number_column_indices = table_info.number_column_indices
      word_column_indices = table_info.word_column_indices
      example.word_columns = table_info.word_columns
      example.number_columns = table_info.number_columns
      example.word_column_names = table_info.word_column_names
      example.processed_number_columns = table_info.processed_number_columns
      example.processed_word_columns = table_info.processed_word_columns
      example.number_column_names = table_info.number_column_names
      example.number_lookup_matrix = example.lookup_matrix[:,
                                                           number_column_indices]
      example.word_lookup_matrix = example.lookup_matrix[:, word_column_indices]
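  # Note (not an original comment): in the lookup matrix, cell (row, col)
  # holds float(ans_index) + 1.0 when that cell matches answer number
  # ans_index, and 0.0 otherwise, so zero unambiguously means "no match".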
  def load(self):
    train_data = []
    dev_data = []
    test_data = []
    self.load_annotated_data(
        os.path.join(self.data_folder, "training.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.train_loader.load()
    self.dev_loader.load()
    for i in range(self.train_loader.num_questions()):
      example = self.train_loader.examples[i]
      train_data.append(self.annotated_examples[example])
    for i in range(self.dev_loader.num_questions()):
      example = self.dev_loader.examples[i]
      dev_data.append(self.annotated_examples[example])
    self.load_annotated_data(
        os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
    self.load_annotated_tables()
    self.answer_classification()
    self.test_loader.load()
    for i in range(self.test_loader.num_questions()):
      example = self.test_loader.examples[i]
      test_data.append(self.annotated_examples[example])
    return train_data, dev_data, test_data
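# Example usage (a minimal sketch; the split file names and the data path
# below are illustrative assumptions, not taken from this file):
#
#   generator = WikiQuestionGenerator("random-split-1-train.examples",
#                                     "random-split-1-dev.examples",
#                                     "pristine-unseen-tables.examples",
#                                     "/path/to/wiki_data")
#   train_data, dev_data, test_data = generator.load()
#   print len(train_data), len(dev_data), len(test_data)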