# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for constructing the vocabulary, converting the examples to
integer format and building the masks required for batch computation.

Author: aneelakantan (Arvind Neelakantan)
"""
- import copy
- import numbers
- import numpy as np
- import wiki_data
def return_index(a):
    """Return the position of the first element of `a` equal to 1.0.

    Args:
      a: iterable of values comparable with 1.0.

    Returns:
      The zero-based index of the first element equal to 1.0, or None when
      no element matches.
    """
    for index, value in enumerate(a):
        if value == 1.0:
            return index
    return None
def _count_vocab_word(word, utility):
    """Add `word` to the vocabulary held by `utility`, or bump its count.

    A word already present in `utility.word_ids` only has its entry in
    `utility.word_count` incremented.  A new word is appended to
    `utility.words`, gets a count of 1 and receives the next free integer id
    (the current size of `utility.word_ids`), which is also recorded in
    `utility.reverse_word_ids`.
    """
    if word in utility.word_ids:
        utility.word_count[word] += 1
        return
    utility.words.append(word)
    utility.word_count[word] = 1
    utility.word_ids[word] = len(utility.word_ids)
    utility.reverse_word_ids[utility.word_ids[word]] = word


def construct_vocab(data, utility, add_word=False):
    """Snapshot every example's table and grow the vocabulary from `data`.

    For each example, deep copies of its number/word columns and of their
    column names are stored on the example as `original_nc`, `original_wc`,
    `original_nc_names` and `original_wc_names`.  Unless `add_word` is true
    or the example is flagged `is_bad_example`, every non-numeric token of
    the question and of the word/number column names is then counted into
    the vocabulary tables of `utility`.

    Args:
      data: iterable of example objects (mutated in place).
      utility: object exposing `words`, `word_count`, `word_ids` and
        `reverse_word_ids`.
      add_word: when true the vocabulary is left untouched and only the
        snapshots are taken.  Note that, despite its name, a true value
        *disables* adding words.

    Returns:
      None.
    """
    for example in data:
        example.original_nc = copy.deepcopy(example.number_columns)
        example.original_wc = copy.deepcopy(example.word_columns)
        example.original_nc_names = copy.deepcopy(example.number_column_names)
        example.original_wc_names = copy.deepcopy(example.word_column_names)
        if add_word or example.is_bad_example:
            continue
        # Question first, then word-column names, then number-column names:
        # this order determines the ids handed out to new words.
        token_lists = [example.question]
        token_lists.extend(example.word_column_names)
        token_lists.extend(example.number_column_names)
        for tokens in token_lists:
            for word in tokens:
                if not isinstance(word, numbers.Number):
                    _count_vocab_word(word, utility)
def word_lookup(word, utility):
    """Return `word` if it is in the vocabulary, else the unknown-word token.

    Args:
      word: token to look up.
      utility: object exposing the `word_ids` mapping and `unk_token`.

    Returns:
      `word` itself when it is a key of `utility.word_ids`, otherwise
      `utility.unk_token`.
    """
    # `in` replaces dict.has_key, which no longer exists on Python 3.
    return word if word in utility.word_ids else utility.unk_token
def convert_to_int_2d_and_pad(a, utility):
    """Convert rows of tokens into fixed-width rows of vocabulary ids.

    Every row is brought to exactly `utility.FLAGS.max_entry_length` tokens
    (longer rows are truncated, shorter ones padded with
    `utility.dummy_token`) and each token is mapped to its id, with tokens
    missing from the vocabulary mapped through `word_lookup` to the
    unknown-word token.

    Side effect: a row shorter than the target width is padded *in place*,
    so the caller's list grows.  A row that is too long is sliced into a new
    list and the caller's row is left as it was.

    Args:
      a: iterable of mutable token lists.
      utility: object exposing `FLAGS.max_entry_length`, `dummy_token`,
        `word_ids` and whatever `word_lookup` reads.

    Returns:
      A new list with one list of ids per input row.
    """
    width = utility.FLAGS.max_entry_length
    ans = []
    for row in a:
        if len(row) > width:
            row = row[0:width]
        missing = width - len(row)
        if missing > 0:
            # Deliberately in place: mirrors the historical append loop.
            row.extend([utility.dummy_token] * missing)
        assert len(row) == width
        ans.append(
            [utility.word_ids[word_lookup(word, utility)] for word in row])
    return ans
def convert_to_bool_and_pad(a, utility):
    """Turn a 2-D numeric array into padded rows of booleans.

    Args:
      a: object with a `tolist()` method yielding a list of rows (the
        callers in this file pass NumPy arrays).
      utility: object exposing `FLAGS.max_elements`.

    Returns:
      A list of lists in which every value below 1 became False and every
      other value True, each row right-padded with False up to
      `FLAGS.max_elements` entries.  Rows already that long or longer are
      not truncated.
    """
    max_elements = utility.FLAGS.max_elements
    return [
        # `not (v < 1)` rather than `v >= 1` keeps the original outcome for
        # values that compare false both ways (e.g. NaN).
        [not (value < 1) for value in row] + [False] *
        (max_elements - len(row))
        for row in a.tolist()
    ]
# Module-level registry of table keys: complete_wiki_processing stores
# `seen_tables[example.table_key] = 1` for every example it pads.
seen_tables = {}
def partial_match(question, table, number):
    """Mark the table cells that share at least one token with the question.

    Args:
      question: sequence of question tokens.
      table: list of columns, each a list of cells.
      number: when true a cell matches if it *equals* some question token;
        otherwise a cell matches if it *contains* some question token
        (`token in cell`).

    Returns:
      A pair `(answer, match)`.  `answer` mirrors the shape of `table` with
      1.0 at matching cells and 0 elsewhere; `match` maps the index of every
      column with at least one matching cell to 1.0.
    """
    answer = []
    match = {}
    for col_idx, column in enumerate(table):
        flags = []
        for cell in column:
            if number:
                hit = any(word == cell for word in question)
            else:
                hit = any(word in cell for word in question)
            if hit:
                match[col_idx] = 1.0
            flags.append(1.0 if hit else 0)
        answer.append(flags)
    return answer, match
def exact_match(question, table, number):
    """Mark the table cells that occur verbatim in the question.

    Args:
      question: list of question tokens.
      table: list of columns, each a list of cells.  With `number` false a
        cell is itself a token sequence compared against question slices.
      number: when true a cell matches if it equals a single question token;
        otherwise it matches if it equals a contiguous run of question
        tokens.

    Returns:
      A triple `(answer, match, matched_indices)`.  `answer` mirrors the
      shape of `table` with 1.0 at matching cells and 0 elsewhere; `match`
      maps each column index with a matching cell to 1.0;
      `matched_indices` lists a `(start, length)` pair for every question
      span matched by a non-numeric cell (always empty when `number` is
      true).
    """
    answer = [[0] * len(column) for column in table]
    match = {}
    matched_indices = []
    question_len = len(question)
    for col_idx, column in enumerate(table):
        for row_idx, cell in enumerate(column):
            if number:
                if any(word == cell for word in question):
                    match[col_idx] = 1.0
                    answer[col_idx][row_idx] = 1.0
                continue
            span = len(cell)
            for start in range(question_len):
                if start + span > question_len:
                    # Later starts leave even less room for the cell.
                    break
                if cell == question[start:start + span]:
                    match[col_idx] = 1.0
                    answer[col_idx][row_idx] = 1.0
                    matched_indices.append((start, span))
    return answer, match, matched_indices
def partial_column_match(question, table, number):
    """Flag the column names that contain at least one question token.

    Args:
      question: sequence of question tokens.
      table: list of column names; a name matches when `token in name`
        holds for some question token.
      number: unused; kept for signature parity with the other matchers.

    Returns:
      A list with one entry per column name: 1.0 on a match, 0 otherwise.
    """
    return [
        1.0 if any(word in name for word in question) else 0
        for name in table
    ]
def exact_column_match(question, table, number):
    """Flag the column names that occur verbatim in the question.

    Args:
      question: list of question tokens.
      table: list of column names, each a token sequence compared against
        contiguous question slices.
      number: unused; kept for signature parity with the other matchers.

    Returns:
      A pair `(answer, matched_indices)`.  `answer` has one entry per column
      name (1.0 on a match, 0 otherwise); `matched_indices` lists a
      `(start, length)` pair for every matching question span.
    """
    answer = []
    matched_indices = []
    question_len = len(question)
    for name in table:
        span = len(name)
        found = False
        for start in range(question_len):
            if start + span > question_len:
                # Later starts leave even less room for the name.
                break
            if name == question[start:start + span]:
                found = True
                matched_indices.append((start, span))
        answer.append(1.0 if found else 0)
    return answer, matched_indices
def get_max_entry(a):
    """Return the most frequent entry of `a` if it occurs more than once.

    Entries equal to the string "UNK, " are ignored.  Among entries with the
    same highest count, the one encountered first in the dictionary's
    iteration order wins.

    Args:
      a: iterable of hashable entries.

    Returns:
      The most frequent entry when its count exceeds 1, otherwise the
      sentinel -1.0 (also returned when nothing is left to count).
    """
    counts = {}
    for entry in a:
        if entry != "UNK, ":
            counts[entry] = counts.get(entry, 0) + 1
    if not counts:
        return -1.0
    # max() keeps the first of several equal maxima, which is what taking
    # element 0 of a stable descending sort used to yield.
    best, count = max(counts.items(), key=lambda item: item[1])
    return best if count > 1 else -1.0
def list_join(a):
    """Concatenate the string form of every item of `a`, each followed by ", ".

    The separator is appended after every item, including the last one, so
    `["a", "b"]` becomes `"a, b, "` and an empty input becomes `""`.

    Args:
      a: iterable of items convertible with `str`.

    Returns:
      The joined string.
    """
    return "".join(str(item) + ", " for item in a)
def group_by_max(table, number):
    """Mark, per column, the cells holding that column's most frequent entry.

    Args:
      table: list of columns, each a list of cells.
      number: when true cells are compared as they are; otherwise every cell
        is first flattened to a string with `list_join`.

    Returns:
      A list of lists mirroring the shape of `table`, with 1.0 where the
      cell equals the value returned by `get_max_entry` for its column and
      0.0 elsewhere.

    NOTE(review): `get_max_entry` returns -1.0 when a column has no repeated
    entry, so a numeric cell that itself equals -1.0 is flagged in that
    case.  Left unchanged here; confirm whether that is intended.
    """
    answer = []
    for column in table:
        if number:
            entries = column
        else:
            entries = [list_join(cell) for cell in column]
        max_entry = get_max_entry(entries)
        answer.append(
            [1.0 if max_entry == entry else 0.0 for entry in entries])
    return answer
def pick_one(a):
    """Tell whether any row of `a` contains the value 1.0.

    Args:
      a: iterable of containers supporting the `in` operator.

    Returns:
      True if some row contains 1.0, False otherwise (including for an
      empty `a`).
    """
    return any(1.0 in row for row in a)
def check_processed_cols(col, utility):
    """Tell whether `col` holds at least one real (non-placeholder) value.

    Args:
      col: iterable of processed column values.
      utility: object exposing `FLAGS.pad_int` and
        `FLAGS.bad_number_pre_process`, the two placeholder values.

    Returns:
      True if some value differs from both placeholders, False otherwise
      (including for an empty column).
    """
    pad_int = utility.FLAGS.pad_int
    bad_number = utility.FLAGS.bad_number_pre_process
    return any(value != pad_int and value != bad_number for value in col)
def _pad_right(values, size, fill):
    """Return `values` followed by enough `fill` items to reach `size`."""
    return values + [fill] * (size - len(values))


def _pad_left(values, size, fill):
    """Return `values` preceded by enough `fill` items to reach `size`."""
    return [fill] * (size - len(values)) + values


def complete_wiki_processing(data, utility, train=True):
    """Convert usable examples to padded, integer-encoded model inputs.

    Examples flagged `is_bad_example` are skipped.  Every other example is
    annotated in place with: question/table match features, the question as
    vocabulary ids with its number and date tokens pulled out, number and
    word columns padded to the sizes configured in `utility.FLAGS`,
    column-name ids, and answer/lookup masks.  An example whose column
    length exceeds `FLAGS.max_elements` is still processed but is left out
    of the returned list.

    NOTE(review): this function reached review with its indentation
    stripped; the nesting below is reconstructed from the statement order
    and should be checked against the upstream file.

    Args:
      data: iterable of example objects that already carry the
        `original_*` snapshots read below.
      utility: holder of the vocabulary tables, special tokens and `FLAGS`.
      train: unused.

    Returns:
      The list of processed examples (the same objects, mutated).
    """
    processed_data = []
    num_bad_examples = 0  # tallied as before; the count is never reported
    for example in data:
        if example.is_bad_example:
            num_bad_examples += 1
            continue
        flags = utility.FLAGS
        number_found = 0
        example.string_question = example.question[:]

        # entry match: exact match first, partial match as the fallback
        example.processed_number_columns = example.processed_number_columns[:]
        example.processed_word_columns = example.processed_word_columns[:]
        example.word_exact_match, word_match, matched_indices = exact_match(
            example.string_question, example.original_wc, number=False)
        example.number_exact_match, number_match, _ = exact_match(
            example.string_question, example.original_nc, number=True)
        if not (pick_one(example.word_exact_match) or
                pick_one(example.number_exact_match)):
            assert len(word_match) == 0
            assert len(number_match) == 0
            example.word_exact_match, word_match = partial_match(
                example.string_question, example.original_wc, number=False)

        # group by max
        example.word_group_by_max = group_by_max(example.original_wc, False)
        example.number_group_by_max = group_by_max(example.original_nc, True)

        # column name match: exact match first, partial match as the fallback
        example.word_column_exact_match, _ = exact_column_match(
            example.string_question, example.original_wc_names, number=False)
        example.number_column_exact_match, _ = exact_column_match(
            example.string_question, example.original_nc_names, number=False)
        if (1.0 not in example.word_column_exact_match and
                1.0 not in example.number_column_exact_match):
            example.word_column_exact_match = partial_column_match(
                example.string_question, example.original_wc_names,
                number=False)
            example.number_column_exact_match = partial_column_match(
                example.string_question, example.original_nc_names,
                number=False)

        # Marker tokens telling the model that some entry / column matched.
        if word_match or number_match:
            example.question.append(utility.entry_match_token)
        if (1.0 in example.word_column_exact_match or
                1.0 in example.number_column_exact_match):
            example.question.append(utility.column_match_token)
        example.string_question = example.question[:]

        example.number_lookup_matrix = np.transpose(
            example.number_lookup_matrix)[:]
        example.word_lookup_matrix = np.transpose(
            example.word_lookup_matrix)[:]
        example.columns = example.number_columns[:]
        example.word_columns = example.word_columns[:]
        example.len_total_cols = (len(example.word_column_names) +
                                  len(example.number_column_names))
        example.column_names = example.number_column_names[:]
        example.word_column_names = example.word_column_names[:]
        example.string_column_names = example.number_column_names[:]
        example.string_word_column_names = example.word_column_names[:]
        example.sorted_number_index = []
        example.sorted_word_index = []
        example.column_mask = []
        example.word_column_mask = []
        example.processed_column_mask = []
        example.processed_word_column_mask = []
        example.word_column_entry_mask = []
        example.question_attention_mask = []
        example.question_number = example.question_number_1 = -1
        example.ordinal_question = []
        example.ordinal_question_one = []
        new_question = []
        if len(example.number_columns) > 0:
            example.len_col = len(example.number_columns[0])
        else:
            example.len_col = len(example.word_columns[0])

        # Question spans that matched a table entry exactly become UNK.
        for span_start, span_len in matched_indices:
            for offset in range(span_len):
                example.question[span_start + offset] = utility.unk_token

        # Pull the first two numbers/dates out of the question and record,
        # in the ordinal vectors, the word position each one follows.
        for word in example.question:
            if isinstance(word, numbers.Number) or wiki_data.is_date(word):
                if (not isinstance(word, numbers.Number) and
                        wiki_data.is_date(word)):
                    word = word.replace("X", "").replace("-", "")
                number_found += 1
                if number_found == 1:
                    example.question_number = word
                    if len(example.ordinal_question) > 0:
                        example.ordinal_question[-1] = 1.0
                    else:
                        example.ordinal_question.append(1.0)
                elif number_found == 2:
                    example.question_number_1 = word
                    if len(example.ordinal_question_one) > 0:
                        example.ordinal_question_one[-1] = 1.0
                    else:
                        example.ordinal_question_one.append(1.0)
            else:
                new_question.append(word)
                example.ordinal_question.append(0.0)
                example.ordinal_question_one.append(0.0)
        example.question = [
            utility.word_ids[word_lookup(w, utility)] for w in new_question
        ]
        example.question_attention_mask = [0.0] * len(example.question)
        # when the first question number occurs before a word
        example.ordinal_question = example.ordinal_question[
            0:len(example.question)]
        example.ordinal_question_one = example.ordinal_question_one[
            0:len(example.question)]

        # question padding (on the left)
        dummy_id = utility.word_ids[utility.dummy_token]
        example.question = _pad_left(
            example.question, flags.question_length, dummy_id)
        example.question_attention_mask = _pad_left(
            example.question_attention_mask, flags.question_length, -10000.0)
        example.ordinal_question = _pad_left(
            example.ordinal_question, flags.question_length, 0.0)
        example.ordinal_question_one = _pad_left(
            example.ordinal_question_one, flags.question_length, 0.0)

        # number columns and related padding
        num_cols = len(example.columns)
        for col_idx, column in enumerate(example.number_columns):
            processed = example.processed_number_columns[col_idx]
            if check_processed_cols(processed, utility):
                example.processed_column_mask.append(0.0)
            order = sorted(
                range(len(processed)),
                key=lambda k: processed[k],
                reverse=True)
            example.sorted_number_index.append(
                _pad_right(order, flags.max_elements, flags.pad_int))
            example.columns[col_idx] = _pad_right(
                column, flags.max_elements, flags.pad_int)
            example.processed_number_columns[col_idx] += [flags.pad_int] * (
                flags.max_elements -
                len(example.processed_number_columns[col_idx]))
            example.column_mask.append(0.0)
        for _ in range(num_cols, flags.max_number_cols):
            example.sorted_number_index.append(
                [flags.pad_int] * flags.max_elements)
            example.columns.append([flags.pad_int] * flags.max_elements)
            example.processed_number_columns.append(
                [flags.pad_int] * flags.max_elements)
            example.number_exact_match.append([0.0] * flags.max_elements)
            example.number_group_by_max.append([0.0] * flags.max_elements)
            example.column_mask.append(-100000000.0)
            example.processed_column_mask.append(-100000000.0)
            example.number_column_exact_match.append(0.0)
            example.column_names.append([utility.dummy_token])

        # word columns and related padding
        word_num_cols = len(example.word_columns)
        for col_idx, column in enumerate(example.word_columns):
            processed = example.processed_word_columns[col_idx]
            if check_processed_cols(processed, utility):
                example.processed_word_column_mask.append(0.0)
            order = sorted(
                range(len(processed)),
                key=lambda k: processed[k],
                reverse=True)
            example.sorted_word_index.append(
                _pad_right(order, flags.max_elements, flags.pad_int))
            column = convert_to_int_2d_and_pad(column, utility)
            example.word_columns[col_idx] = column + [
                [dummy_id] * flags.max_entry_length
            ] * (flags.max_elements - len(column))
            example.processed_word_columns[col_idx] += [flags.pad_int] * (
                flags.max_elements -
                len(example.processed_word_columns[col_idx]))
            example.word_column_entry_mask.append(
                [0] * len(column) +
                [dummy_id] * (flags.max_elements - len(column)))
            example.word_column_mask.append(0.0)
        for _ in range(word_num_cols, flags.max_word_cols):
            example.sorted_word_index.append(
                [flags.pad_int] * flags.max_elements)
            example.word_columns.append(
                [[dummy_id] * flags.max_entry_length] * flags.max_elements)
            example.word_column_entry_mask.append(
                [dummy_id] * flags.max_elements)
            example.word_exact_match.append([0.0] * flags.max_elements)
            example.word_group_by_max.append([0.0] * flags.max_elements)
            example.processed_word_columns.append(
                [flags.pad_int] * flags.max_elements)
            example.word_column_mask.append(-100000000.0)
            example.processed_word_column_mask.append(-100000000.0)
            example.word_column_exact_match.append(0.0)
            example.word_column_names.append(
                [utility.dummy_token] * flags.max_entry_length)
        seen_tables[example.table_key] = 1

        # convert column and word column names to integers
        example.column_ids = convert_to_int_2d_and_pad(
            example.column_names, utility)
        example.word_column_ids = convert_to_int_2d_and_pad(
            example.word_column_names, utility)

        # pad the per-cell match features to max_elements
        for i_em in range(len(example.number_exact_match)):
            example.number_exact_match[i_em] = _pad_right(
                example.number_exact_match[i_em], flags.max_elements, 0.0)
            example.number_group_by_max[i_em] = _pad_right(
                example.number_group_by_max[i_em], flags.max_elements, 0.0)
        for i_em in range(len(example.word_exact_match)):
            example.word_exact_match[i_em] = _pad_right(
                example.word_exact_match[i_em], flags.max_elements, 0.0)
            example.word_group_by_max[i_em] = _pad_right(
                example.word_group_by_max[i_em], flags.max_elements, 0.0)
        example.exact_match = (
            example.number_exact_match + example.word_exact_match)
        example.group_by_max = (
            example.number_group_by_max + example.word_group_by_max)
        example.exact_column_match = (
            example.number_column_exact_match +
            example.word_column_exact_match)

        # answer and related mask, padding
        example.answer = example.calc_answer
        if example.is_lookup:
            example.number_print_answer = [
                _pad_right(row, flags.max_elements, 0.0)
                for row in example.number_lookup_matrix.tolist()
            ]
            example.word_print_answer = [
                _pad_right(row, flags.max_elements, 0.0)
                for row in example.word_lookup_matrix.tolist()
            ]
            example.number_lookup_matrix = convert_to_bool_and_pad(
                example.number_lookup_matrix, utility)
            example.word_lookup_matrix = convert_to_bool_and_pad(
                example.word_lookup_matrix, utility)
            for _ in range(num_cols, flags.max_number_cols):
                example.number_lookup_matrix.append(
                    [False] * flags.max_elements)
                example.number_print_answer.append(
                    [0.0] * flags.max_elements)
            for _ in range(word_num_cols, flags.max_word_cols):
                example.word_lookup_matrix.append(
                    [False] * flags.max_elements)
                example.word_print_answer.append([0.0] * flags.max_elements)
            example.print_answer = (
                example.number_print_answer + example.word_print_answer)
        else:
            example.print_answer = [[0.0] * flags.max_elements] * (
                flags.max_number_cols + flags.max_word_cols)

        # question_number masks
        if example.question_number == -1:
            example.question_number_mask = np.zeros([flags.max_elements])
        else:
            example.question_number_mask = np.ones([flags.max_elements])
        if example.question_number_1 == -1:
            example.question_number_one_mask = -10000.0
        else:
            example.question_number_one_mask = np.float64(0.0)

        # Tables with more rows than the model supports are dropped.
        if example.len_col > flags.max_elements:
            continue
        processed_data.append(example)
    return processed_data
def _register_special_token(utility, token):
    """Append `token` to the vocabulary of `utility` and return its id.

    The id is the size of `utility.word_ids` before the token is stored.
    """
    utility.words.append(token)
    token_id = len(utility.word_ids)
    utility.word_ids[token] = token_id
    utility.reverse_word_ids[token_id] = token
    return token_id


def add_special_words(utility):
    """Register the four special tokens in the vocabulary of `utility`.

    The entry-match, column-match, dummy and unknown tokens are appended, in
    that order, to `utility.words`, `utility.word_ids` and
    `utility.reverse_word_ids`.  The ids of the first three are additionally
    stored as `entry_match_token_id`, `column_match_token_id` and
    `dummy_token_id`; no id attribute is stored for the unknown token.  The
    ids of the two match tokens are printed.

    Args:
      utility: object exposing the vocabulary tables and the four token
        attributes (`entry_match_token`, `column_match_token`,
        `dummy_token`, `unk_token`).

    Returns:
      None.
    """
    utility.entry_match_token_id = _register_special_token(
        utility, utility.entry_match_token)
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print("entry match token: %s %s" % (
        utility.word_ids[utility.entry_match_token],
        utility.entry_match_token_id))
    utility.column_match_token_id = _register_special_token(
        utility, utility.column_match_token)
    # This line used to be labelled "entry match token" as well.
    print("column match token: %s %s" % (
        utility.word_ids[utility.column_match_token],
        utility.column_match_token_id))
    utility.dummy_token_id = _register_special_token(
        utility, utility.dummy_token)
    _register_special_token(utility, utility.unk_token)
def perform_word_cutoff(utility):
    """Drop rare words from the vocabulary of `utility`.

    Does nothing unless `utility.FLAGS.word_cutoff` is positive.  Otherwise
    every word that has a count in `utility.word_count` below the cutoff and
    is not one of the four special tokens is removed from
    `utility.word_ids` and `utility.words`.  Words without a recorded count
    are kept.  Only those two containers are touched: the ids of the
    remaining words and `utility.reverse_word_ids` stay as they were.

    Args:
      utility: object exposing `FLAGS.word_cutoff`, `word_ids`,
        `word_count`, `words` and the four special-token attributes.

    Returns:
      None.
    """
    cutoff = utility.FLAGS.word_cutoff
    if cutoff > 0:
        protected = (utility.unk_token, utility.dummy_token,
                     utility.entry_match_token, utility.column_match_token)
        # Iterate over a snapshot: popping from a dict while iterating its
        # live key view raises RuntimeError on Python 3.
        for word in list(utility.word_ids):
            if (word in utility.word_count and
                    utility.word_count[word] < cutoff and
                    word not in protected):
                utility.word_ids.pop(word)
                utility.words.remove(word)
def word_dropout(question, utility):
    """Randomly replace question token ids with the unknown-word id.

    When `utility.FLAGS.word_dropout_prob` is not positive the input list is
    returned as is.  Otherwise a new list is built in which every id other
    than `utility.dummy_token_id` is replaced by the id of
    `utility.unk_token` whenever a draw from `utility.random.random()` is
    *greater than* the flag value.  Larger flag values therefore replace
    fewer tokens.  No random draw is made for dummy (padding) ids.

    Args:
      question: list of token ids.
      utility: object exposing `FLAGS.word_dropout_prob`, `dummy_token_id`,
        `word_ids`, `unk_token` and a `random` object with a `random()`
        method.

    Returns:
      The original `question` object when dropout is disabled, otherwise a
      new list of the same length.
    """
    threshold = utility.FLAGS.word_dropout_prob
    if not threshold > 0.0:
        return question
    return [
        utility.word_ids[utility.unk_token]
        if (token_id != utility.dummy_token_id and
            utility.random.random() > threshold)
        else token_id
        for token_id in question
    ]
def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
    """Build the feed dictionary for one batch of processed examples.

    The batch is `data[curr]` .. `data[curr + batch_size - 1]`; indexing past
    the end of `data` raises IndexError.  Each key is an attribute of `gr`
    and each value collects one attribute from every example in the batch.

    Args:
      data: indexable sequence of processed examples.
      curr: index of the first example of the batch.
      batch_size: number of examples in the batch.
      gr: object whose `batch_*` attributes are used as dictionary keys.
      train: when true the questions are passed through `word_dropout`.
      utility: forwarded to `word_dropout`; only read when `train` is true.

    Returns:
      A dict keyed by the `gr.batch_*` attributes.  Values are lists with
      one item per example, except the question-number entries flagged
      below, which are NumPy arrays reshaped to `(batch_size, 1)`.
    """
    batch = [data[curr + offset] for offset in range(batch_size)]
    feed_dict = {}
    if train:
        feed_dict[gr.batch_question] = [
            word_dropout(example.question, utility) for example in batch
        ]
    else:
        feed_dict[gr.batch_question] = [example.question for example in batch]

    # (attribute of `gr` used as key, example attribute, reshape to column?)
    fields = (
        ("batch_question_attention_mask", "question_attention_mask", False),
        ("batch_answer", "answer", False),
        ("batch_number_column", "columns", False),
        ("batch_processed_number_column", "processed_number_columns", False),
        ("batch_processed_sorted_index_number_column",
         "sorted_number_index", False),
        ("batch_processed_sorted_index_word_column",
         "sorted_word_index", False),
        ("batch_question_number", "question_number", True),
        ("batch_question_number_one", "question_number_1", True),
        ("batch_question_number_mask", "question_number_mask", False),
        ("batch_question_number_one_mask", "question_number_one_mask", True),
        ("batch_print_answer", "print_answer", False),
        ("batch_exact_match", "exact_match", False),
        ("batch_group_by_max", "group_by_max", False),
        ("batch_column_exact_match", "exact_column_match", False),
        ("batch_ordinal_question", "ordinal_question", False),
        ("batch_ordinal_question_one", "ordinal_question_one", False),
        ("batch_number_column_mask", "column_mask", False),
        ("batch_number_column_names", "column_ids", False),
        ("batch_processed_word_column", "processed_word_columns", False),
        ("batch_word_column_mask", "word_column_mask", False),
        ("batch_word_column_names", "word_column_ids", False),
        ("batch_word_column_entry_mask", "word_column_entry_mask", False),
    )
    for placeholder_name, example_attr, as_column in fields:
        values = [getattr(example, example_attr) for example in batch]
        if as_column:
            values = np.array(values).reshape((batch_size, 1))
        feed_dict[getattr(gr, placeholder_name)] = values
    return feed_dict
|