Quoc Le 8 years ago
Parent
Commit
124a501db1
7 changed files with 2282 additions and 0 deletions
  README.md             +19   -0
  data_utils.py         +664  -0
  model.py              +678  -0
  neural_programmer.py  +234  -0
  nn_utils.py           +68   -0
  parameters.py         +89   -0
  wiki_data.py          +530  -0

+ 19 - 0
README.md

@@ -20,3 +20,22 @@ To propose a model for inclusion please submit a pull request.
 - [textsum](textsum) -- sequence-to-sequence with attention model for text summarization.
 - [transformer](transformer) -- spatial transformer network, which allows the spatial manipulation of data within the network
 - [im2txt](im2txt) -- image-to-text neural network for image captioning.
+# Neural Programmer
+Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge
+
+Download the data from http://www-nlp.stanford.edu/software/sempre/wikitable/
+Set the data_dir flag to the location of the downloaded data.
+
+Training:
+python neural_programmer.py
+
+The models are written to FLAGS.output_dir.
+
+
+Testing:
+python neural_programmer.py --evaluator_job=True
+
+The models are loaded from FLAGS.output_dir.
+The evaluation is done on development data.
+
+Maintained by Arvind Neelakantan (arvind2505)
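
For reference, the two commands above might be combined with explicit flag values as follows (a minimal sketch, not part of the commit; it assumes data_dir and output_dir are exposed as command-line flags in the same way as evaluator_job, and the paths are placeholders):

python neural_programmer.py --data_dir=/path/to/wikitable --output_dir=/tmp/neural_programmer
python neural_programmer.py --data_dir=/path/to/wikitable --output_dir=/tmp/neural_programmer --evaluator_job=True

The second command evaluates, on the development data, the checkpoints written by the first.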

+ 664 - 0
data_utils.py

@@ -0,0 +1,664 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for constructing vocabulary, converting the examples to integer format and building the required masks for batch computation Author: aneelakantan (Arvind Neelakantan)
+"""
+
+import copy
+import numbers
+import numpy as np
+import wiki_data
+
+
+def return_index(a):
+  for i in range(len(a)):
+    if (a[i] == 1.0):
+      return i
+
+
+def construct_vocab(data, utility, add_word=False):
+  ans = []
+  for example in data:
+    sent = ""
+    for word in example.question:
+      if (not (isinstance(word, numbers.Number))):
+        sent += word + " "
+    example.original_nc = copy.deepcopy(example.number_columns)
+    example.original_wc = copy.deepcopy(example.word_columns)
+    example.original_nc_names = copy.deepcopy(example.number_column_names)
+    example.original_wc_names = copy.deepcopy(example.word_column_names)
+    if (add_word):
+      continue
+    number_found = 0
+    if (not (example.is_bad_example)):
+      for word in example.question:
+        if (isinstance(word, numbers.Number)):
+          number_found += 1
+        else:
+          if (not (utility.word_ids.has_key(word))):
+            utility.words.append(word)
+            utility.word_count[word] = 1
+            utility.word_ids[word] = len(utility.word_ids)
+            utility.reverse_word_ids[utility.word_ids[word]] = word
+          else:
+            utility.word_count[word] += 1
+      for col_name in example.word_column_names:
+        for word in col_name:
+          if (isinstance(word, numbers.Number)):
+            number_found += 1
+          else:
+            if (not (utility.word_ids.has_key(word))):
+              utility.words.append(word)
+              utility.word_count[word] = 1
+              utility.word_ids[word] = len(utility.word_ids)
+              utility.reverse_word_ids[utility.word_ids[word]] = word
+            else:
+              utility.word_count[word] += 1
+      for col_name in example.number_column_names:
+        for word in col_name:
+          if (isinstance(word, numbers.Number)):
+            number_found += 1
+          else:
+            if (not (utility.word_ids.has_key(word))):
+              utility.words.append(word)
+              utility.word_count[word] = 1
+              utility.word_ids[word] = len(utility.word_ids)
+              utility.reverse_word_ids[utility.word_ids[word]] = word
+            else:
+              utility.word_count[word] += 1
+
+
+def word_lookup(word, utility):
+  if (utility.word_ids.has_key(word)):
+    return word
+  else:
+    return utility.unk_token
+
+
+def convert_to_int_2d_and_pad(a, utility):
+  ans = []
+  #print a
+  for b in a:
+    temp = []
+    if (len(b) > utility.FLAGS.max_entry_length):
+      b = b[0:utility.FLAGS.max_entry_length]
+    for remaining in range(len(b), utility.FLAGS.max_entry_length):
+      b.append(utility.dummy_token)
+    assert len(b) == utility.FLAGS.max_entry_length
+    for word in b:
+      temp.append(utility.word_ids[word_lookup(word, utility)])
+    ans.append(temp)
+  #print ans
+  return ans
+
+
+def convert_to_bool_and_pad(a, utility):
+  a = a.tolist()
+  for i in range(len(a)):
+    for j in range(len(a[i])):
+      if (a[i][j] < 1):
+        a[i][j] = False
+      else:
+        a[i][j] = True
+    a[i] = a[i] + [False] * (utility.FLAGS.max_elements - len(a[i]))
+  return a
+
+
+seen_tables = {}
+
+
+def partial_match(question, table, number):
+  answer = []
+  match = {}
+  for i in range(len(table)):
+    temp = []
+    for j in range(len(table[i])):
+      temp.append(0)
+    answer.append(temp)
+  for i in range(len(table)):
+    for j in range(len(table[i])):
+      for word in question:
+        if (number):
+          if (word == table[i][j]):
+            answer[i][j] = 1.0
+            match[i] = 1.0
+        else:
+          if (word in table[i][j]):
+            answer[i][j] = 1.0
+            match[i] = 1.0
+  return answer, match
+
+
+def exact_match(question, table, number):
+  #performs exact match operation
+  answer = []
+  match = {}
+  matched_indices = []
+  for i in range(len(table)):
+    temp = []
+    for j in range(len(table[i])):
+      temp.append(0)
+    answer.append(temp)
+  for i in range(len(table)):
+    for j in range(len(table[i])):
+      if (number):
+        for word in question:
+          if (word == table[i][j]):
+            match[i] = 1.0
+            answer[i][j] = 1.0
+      else:
+        table_entry = table[i][j]
+        for k in range(len(question)):
+          if (k + len(table_entry) <= len(question)):
+            if (table_entry == question[k:(k + len(table_entry))]):
+              #if(len(table_entry) == 1):
+              #print "match: ", table_entry, question
+              match[i] = 1.0
+              answer[i][j] = 1.0
+              matched_indices.append((k, len(table_entry)))
+  return answer, match, matched_indices
+
+
+def partial_column_match(question, table, number):
+  answer = []
+  for i in range(len(table)):
+    answer.append(0)
+  for i in range(len(table)):
+    for word in question:
+      if (word in table[i]):
+        answer[i] = 1.0
+  return answer
+
+
+def exact_column_match(question, table, number):
+  #performs exact match on column names
+  answer = []
+  matched_indices = []
+  for i in range(len(table)):
+    answer.append(0)
+  for i in range(len(table)):
+    table_entry = table[i]
+    for k in range(len(question)):
+      if (k + len(table_entry) <= len(question)):
+        if (table_entry == question[k:(k + len(table_entry))]):
+          answer[i] = 1.0
+          matched_indices.append((k, len(table_entry)))
+  return answer, matched_indices
+
+
+def get_max_entry(a):
+  e = {}
+  for w in a:
+    if (w != "UNK, "):
+      if (e.has_key(w)):
+        e[w] += 1
+      else:
+        e[w] = 1
+  if (len(e) > 0):
+    (key, val) = sorted(e.items(), key=lambda x: -1 * x[1])[0]
+    if (val > 1):
+      return key
+    else:
+      return -1.0
+  else:
+    return -1.0
+
+
+def list_join(a):
+  ans = ""
+  for w in a:
+    ans += str(w) + ", "
+  return ans
+
+
+def group_by_max(table, number):
+  #computes the most frequently occurring entry in a column
+  answer = []
+  for i in range(len(table)):
+    temp = []
+    for j in range(len(table[i])):
+      temp.append(0)
+    answer.append(temp)
+  for i in range(len(table)):
+    if (number):
+      curr = table[i]
+    else:
+      curr = [list_join(w) for w in table[i]]
+    max_entry = get_max_entry(curr)
+    #print i, max_entry
+    for j in range(len(curr)):
+      if (max_entry == curr[j]):
+        answer[i][j] = 1.0
+      else:
+        answer[i][j] = 0.0
+  return answer
+
+
+def pick_one(a):
+  for i in range(len(a)):
+    if (1.0 in a[i]):
+      return True
+  return False
+
+
+def check_processed_cols(col, utility):
+  return True in [
+      True for y in col
+      if (y != utility.FLAGS.pad_int and y !=
+          utility.FLAGS.bad_number_pre_process)
+  ]
+
+
+def complete_wiki_processing(data, utility, train=True):
+  #convert to integers and padding
+  processed_data = []
+  num_bad_examples = 0
+  for example in data:
+    number_found = 0
+    if (example.is_bad_example):
+      num_bad_examples += 1
+    if (not (example.is_bad_example)):
+      example.string_question = example.question[:]
+      #entry match
+      example.processed_number_columns = example.processed_number_columns[:]
+      example.processed_word_columns = example.processed_word_columns[:]
+      example.word_exact_match, word_match, matched_indices = exact_match(
+          example.string_question, example.original_wc, number=False)
+      example.number_exact_match, number_match, _ = exact_match(
+          example.string_question, example.original_nc, number=True)
+      if (not (pick_one(example.word_exact_match)) and not (
+          pick_one(example.number_exact_match))):
+        assert len(word_match) == 0
+        assert len(number_match) == 0
+        example.word_exact_match, word_match = partial_match(
+            example.string_question, example.original_wc, number=False)
+      #group by max
+      example.word_group_by_max = group_by_max(example.original_wc, False)
+      example.number_group_by_max = group_by_max(example.original_nc, True)
+      #column name match
+      example.word_column_exact_match, wcol_matched_indices = exact_column_match(
+          example.string_question, example.original_wc_names, number=False)
+      example.number_column_exact_match, ncol_matched_indices = exact_column_match(
+          example.string_question, example.original_nc_names, number=False)
+      if (not (1.0 in example.word_column_exact_match) and not (
+          1.0 in example.number_column_exact_match)):
+        example.word_column_exact_match = partial_column_match(
+            example.string_question, example.original_wc_names, number=False)
+        example.number_column_exact_match = partial_column_match(
+            example.string_question, example.original_nc_names, number=False)
+      if (len(word_match) > 0 or len(number_match) > 0):
+        example.question.append(utility.entry_match_token)
+      if (1.0 in example.word_column_exact_match or
+          1.0 in example.number_column_exact_match):
+        example.question.append(utility.column_match_token)
+      example.string_question = example.question[:]
+      example.number_lookup_matrix = np.transpose(
+          example.number_lookup_matrix)[:]
+      example.word_lookup_matrix = np.transpose(example.word_lookup_matrix)[:]
+      example.columns = example.number_columns[:]
+      example.word_columns = example.word_columns[:]
+      example.len_total_cols = len(example.word_column_names) + len(
+          example.number_column_names)
+      example.column_names = example.number_column_names[:]
+      example.word_column_names = example.word_column_names[:]
+      example.string_column_names = example.number_column_names[:]
+      example.string_word_column_names = example.word_column_names[:]
+      example.sorted_number_index = []
+      example.sorted_word_index = []
+      example.column_mask = []
+      example.word_column_mask = []
+      example.processed_column_mask = []
+      example.processed_word_column_mask = []
+      example.word_column_entry_mask = []
+      example.question_attention_mask = []
+      example.question_number = example.question_number_1 = -1
+      example.question_attention_mask = []
+      example.ordinal_question = []
+      example.ordinal_question_one = []
+      new_question = []
+      if (len(example.number_columns) > 0):
+        example.len_col = len(example.number_columns[0])
+      else:
+        example.len_col = len(example.word_columns[0])
+      for (start, length) in matched_indices:
+        for j in range(length):
+          example.question[start + j] = utility.unk_token
+      #print example.question
+      for word in example.question:
+        if (isinstance(word, numbers.Number) or wiki_data.is_date(word)):
+          if (not (isinstance(word, numbers.Number)) and
+              wiki_data.is_date(word)):
+            word = word.replace("X", "").replace("-", "")
+          number_found += 1
+          if (number_found == 1):
+            example.question_number = word
+            if (len(example.ordinal_question) > 0):
+              example.ordinal_question[len(example.ordinal_question) - 1] = 1.0
+            else:
+              example.ordinal_question.append(1.0)
+          elif (number_found == 2):
+            example.question_number_1 = word
+            if (len(example.ordinal_question_one) > 0):
+              example.ordinal_question_one[len(example.ordinal_question_one) -
+                                           1] = 1.0
+            else:
+              example.ordinal_question_one.append(1.0)
+        else:
+          new_question.append(word)
+          example.ordinal_question.append(0.0)
+          example.ordinal_question_one.append(0.0)
+      example.question = [
+          utility.word_ids[word_lookup(w, utility)] for w in new_question
+      ]
+      example.question_attention_mask = [0.0] * len(example.question)
+      #when the first question number occurs before a word
+      example.ordinal_question = example.ordinal_question[0:len(
+          example.question)]
+      example.ordinal_question_one = example.ordinal_question_one[0:len(
+          example.question)]
+      #question-padding
+      example.question = [utility.word_ids[utility.dummy_token]] * (
+          utility.FLAGS.question_length - len(example.question)
+      ) + example.question
+      example.question_attention_mask = [-10000.0] * (
+          utility.FLAGS.question_length - len(example.question_attention_mask)
+      ) + example.question_attention_mask
+      example.ordinal_question = [0.0] * (utility.FLAGS.question_length -
+                                          len(example.ordinal_question)
+                                         ) + example.ordinal_question
+      example.ordinal_question_one = [0.0] * (utility.FLAGS.question_length -
+                                              len(example.ordinal_question_one)
+                                             ) + example.ordinal_question_one
+      if (True):
+        #number columns and related padding
+        num_cols = len(example.columns)
+        start = 0
+        for column in example.number_columns:
+          if (check_processed_cols(example.processed_number_columns[start],
+                                   utility)):
+            example.processed_column_mask.append(0.0)
+          sorted_index = sorted(
+              range(len(example.processed_number_columns[start])),
+              key=lambda k: example.processed_number_columns[start][k],
+              reverse=True)
+          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
+              utility.FLAGS.max_elements - len(sorted_index))
+          example.sorted_number_index.append(sorted_index)
+          example.columns[start] = column + [utility.FLAGS.pad_int] * (
+              utility.FLAGS.max_elements - len(column))
+          example.processed_number_columns[start] += [utility.FLAGS.pad_int] * (
+              utility.FLAGS.max_elements -
+              len(example.processed_number_columns[start]))
+          start += 1
+          example.column_mask.append(0.0)
+        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
+          example.sorted_number_index.append([utility.FLAGS.pad_int] *
+                                             (utility.FLAGS.max_elements))
+          example.columns.append([utility.FLAGS.pad_int] *
+                                 (utility.FLAGS.max_elements))
+          example.processed_number_columns.append([utility.FLAGS.pad_int] *
+                                                  (utility.FLAGS.max_elements))
+          example.number_exact_match.append([0.0] *
+                                            (utility.FLAGS.max_elements))
+          example.number_group_by_max.append([0.0] *
+                                             (utility.FLAGS.max_elements))
+          example.column_mask.append(-100000000.0)
+          example.processed_column_mask.append(-100000000.0)
+          example.number_column_exact_match.append(0.0)
+          example.column_names.append([utility.dummy_token])
+        #word columns and related padding
+        start = 0
+        word_num_cols = len(example.word_columns)
+        for column in example.word_columns:
+          if (check_processed_cols(example.processed_word_columns[start],
+                                   utility)):
+            example.processed_word_column_mask.append(0.0)
+          sorted_index = sorted(
+              range(len(example.processed_word_columns[start])),
+              key=lambda k: example.processed_word_columns[start][k],
+              reverse=True)
+          sorted_index = sorted_index + [utility.FLAGS.pad_int] * (
+              utility.FLAGS.max_elements - len(sorted_index))
+          example.sorted_word_index.append(sorted_index)
+          column = convert_to_int_2d_and_pad(column, utility)
+          example.word_columns[start] = column + [[
+              utility.word_ids[utility.dummy_token]
+          ] * utility.FLAGS.max_entry_length] * (utility.FLAGS.max_elements -
+                                                 len(column))
+          example.processed_word_columns[start] += [utility.FLAGS.pad_int] * (
+              utility.FLAGS.max_elements -
+              len(example.processed_word_columns[start]))
+          example.word_column_entry_mask.append([0] * len(column) + [
+              utility.word_ids[utility.dummy_token]
+          ] * (utility.FLAGS.max_elements - len(column)))
+          start += 1
+          example.word_column_mask.append(0.0)
+        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
+          example.sorted_word_index.append([utility.FLAGS.pad_int] *
+                                           (utility.FLAGS.max_elements))
+          example.word_columns.append([[utility.word_ids[utility.dummy_token]] *
+                                       utility.FLAGS.max_entry_length] *
+                                      (utility.FLAGS.max_elements))
+          example.word_column_entry_mask.append(
+              [utility.word_ids[utility.dummy_token]] *
+              (utility.FLAGS.max_elements))
+          example.word_exact_match.append([0.0] * (utility.FLAGS.max_elements))
+          example.word_group_by_max.append([0.0] * (utility.FLAGS.max_elements))
+          example.processed_word_columns.append([utility.FLAGS.pad_int] *
+                                                (utility.FLAGS.max_elements))
+          example.word_column_mask.append(-100000000.0)
+          example.processed_word_column_mask.append(-100000000.0)
+          example.word_column_exact_match.append(0.0)
+          example.word_column_names.append([utility.dummy_token] *
+                                           utility.FLAGS.max_entry_length)
+        seen_tables[example.table_key] = 1
+      #convert column and word column names to integers
+      example.column_ids = convert_to_int_2d_and_pad(example.column_names,
+                                                     utility)
+      example.word_column_ids = convert_to_int_2d_and_pad(
+          example.word_column_names, utility)
+      for i_em in range(len(example.number_exact_match)):
+        example.number_exact_match[i_em] = example.number_exact_match[
+            i_em] + [0.0] * (utility.FLAGS.max_elements -
+                             len(example.number_exact_match[i_em]))
+        example.number_group_by_max[i_em] = example.number_group_by_max[
+            i_em] + [0.0] * (utility.FLAGS.max_elements -
+                             len(example.number_group_by_max[i_em]))
+      for i_em in range(len(example.word_exact_match)):
+        example.word_exact_match[i_em] = example.word_exact_match[
+            i_em] + [0.0] * (utility.FLAGS.max_elements -
+                             len(example.word_exact_match[i_em]))
+        example.word_group_by_max[i_em] = example.word_group_by_max[
+            i_em] + [0.0] * (utility.FLAGS.max_elements -
+                             len(example.word_group_by_max[i_em]))
+      example.exact_match = example.number_exact_match + example.word_exact_match
+      example.group_by_max = example.number_group_by_max + example.word_group_by_max
+      example.exact_column_match = example.number_column_exact_match + example.word_column_exact_match
+      #answer and related mask, padding
+      if (example.is_lookup):
+        example.answer = example.calc_answer
+        example.number_print_answer = example.number_lookup_matrix.tolist()
+        example.word_print_answer = example.word_lookup_matrix.tolist()
+        for i_answer in range(len(example.number_print_answer)):
+          example.number_print_answer[i_answer] = example.number_print_answer[
+              i_answer] + [0.0] * (utility.FLAGS.max_elements -
+                                   len(example.number_print_answer[i_answer]))
+        for i_answer in range(len(example.word_print_answer)):
+          example.word_print_answer[i_answer] = example.word_print_answer[
+              i_answer] + [0.0] * (utility.FLAGS.max_elements -
+                                   len(example.word_print_answer[i_answer]))
+        example.number_lookup_matrix = convert_to_bool_and_pad(
+            example.number_lookup_matrix, utility)
+        example.word_lookup_matrix = convert_to_bool_and_pad(
+            example.word_lookup_matrix, utility)
+        for remaining in range(num_cols, utility.FLAGS.max_number_cols):
+          example.number_lookup_matrix.append([False] *
+                                              utility.FLAGS.max_elements)
+          example.number_print_answer.append([0.0] * utility.FLAGS.max_elements)
+        for remaining in range(word_num_cols, utility.FLAGS.max_word_cols):
+          example.word_lookup_matrix.append([False] *
+                                            utility.FLAGS.max_elements)
+          example.word_print_answer.append([0.0] * utility.FLAGS.max_elements)
+        example.print_answer = example.number_print_answer + example.word_print_answer
+      else:
+        example.answer = example.calc_answer
+        example.print_answer = [[0.0] * (utility.FLAGS.max_elements)] * (
+            utility.FLAGS.max_number_cols + utility.FLAGS.max_word_cols)
+      #question_number masks
+      if (example.question_number == -1):
+        example.question_number_mask = np.zeros([utility.FLAGS.max_elements])
+      else:
+        example.question_number_mask = np.ones([utility.FLAGS.max_elements])
+      if (example.question_number_1 == -1):
+        example.question_number_one_mask = -10000.0
+      else:
+        example.question_number_one_mask = np.float64(0.0)
+      if (example.len_col > utility.FLAGS.max_elements):
+        continue
+      processed_data.append(example)
+  return processed_data
+
+
+def add_special_words(utility):
+  utility.words.append(utility.entry_match_token)
+  utility.word_ids[utility.entry_match_token] = len(utility.word_ids)
+  utility.reverse_word_ids[utility.word_ids[
+      utility.entry_match_token]] = utility.entry_match_token
+  utility.entry_match_token_id = utility.word_ids[utility.entry_match_token]
+  print "entry match token: ", utility.word_ids[
+      utility.entry_match_token], utility.entry_match_token_id
+  utility.words.append(utility.column_match_token)
+  utility.word_ids[utility.column_match_token] = len(utility.word_ids)
+  utility.reverse_word_ids[utility.word_ids[
+      utility.column_match_token]] = utility.column_match_token
+  utility.column_match_token_id = utility.word_ids[utility.column_match_token]
+  print "entry match token: ", utility.word_ids[
+      utility.column_match_token], utility.column_match_token_id
+  utility.words.append(utility.dummy_token)
+  utility.word_ids[utility.dummy_token] = len(utility.word_ids)
+  utility.reverse_word_ids[utility.word_ids[
+      utility.dummy_token]] = utility.dummy_token
+  utility.dummy_token_id = utility.word_ids[utility.dummy_token]
+  utility.words.append(utility.unk_token)
+  utility.word_ids[utility.unk_token] = len(utility.word_ids)
+  utility.reverse_word_ids[utility.word_ids[
+      utility.unk_token]] = utility.unk_token
+
+
+def perform_word_cutoff(utility):
+  if (utility.FLAGS.word_cutoff > 0):
+    for word in utility.word_ids.keys():
+      if (utility.word_count.has_key(word) and utility.word_count[word] <
+          utility.FLAGS.word_cutoff and word != utility.unk_token and
+          word != utility.dummy_token and word != utility.entry_match_token and
+          word != utility.column_match_token):
+        utility.word_ids.pop(word)
+        utility.words.remove(word)
+
+
+def word_dropout(question, utility):
+  if (utility.FLAGS.word_dropout_prob > 0.0):
+    new_question = []
+    for i in range(len(question)):
+      if (question[i] != utility.dummy_token_id and
+          utility.random.random() > utility.FLAGS.word_dropout_prob):
+        new_question.append(utility.word_ids[utility.unk_token])
+      else:
+        new_question.append(question[i])
+    return new_question
+  else:
+    return question
+
+
+def generate_feed_dict(data, curr, batch_size, gr, train=False, utility=None):
+  #prepare the feed dictionary
+  feed_dict = {}
+  feed_examples = []
+  for j in range(batch_size):
+    feed_examples.append(data[curr + j])
+  if (train):
+    feed_dict[gr.batch_question] = [
+        word_dropout(feed_examples[j].question, utility)
+        for j in range(batch_size)
+    ]
+  else:
+    feed_dict[gr.batch_question] = [
+        feed_examples[j].question for j in range(batch_size)
+    ]
+  feed_dict[gr.batch_question_attention_mask] = [
+      feed_examples[j].question_attention_mask for j in range(batch_size)
+  ]
+  feed_dict[
+      gr.batch_answer] = [feed_examples[j].answer for j in range(batch_size)]
+  feed_dict[gr.batch_number_column] = [
+      feed_examples[j].columns for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_processed_number_column] = [
+      feed_examples[j].processed_number_columns for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_processed_sorted_index_number_column] = [
+      feed_examples[j].sorted_number_index for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_processed_sorted_index_word_column] = [
+      feed_examples[j].sorted_word_index for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_question_number] = np.array(
+      [feed_examples[j].question_number for j in range(batch_size)]).reshape(
+          (batch_size, 1))
+  feed_dict[gr.batch_question_number_one] = np.array(
+      [feed_examples[j].question_number_1 for j in range(batch_size)]).reshape(
+          (batch_size, 1))
+  feed_dict[gr.batch_question_number_mask] = [
+      feed_examples[j].question_number_mask for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_question_number_one_mask] = np.array(
+      [feed_examples[j].question_number_one_mask for j in range(batch_size)
+      ]).reshape((batch_size, 1))
+  feed_dict[gr.batch_print_answer] = [
+      feed_examples[j].print_answer for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_exact_match] = [
+      feed_examples[j].exact_match for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_group_by_max] = [
+      feed_examples[j].group_by_max for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_column_exact_match] = [
+      feed_examples[j].exact_column_match for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_ordinal_question] = [
+      feed_examples[j].ordinal_question for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_ordinal_question_one] = [
+      feed_examples[j].ordinal_question_one for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_number_column_mask] = [
+      feed_examples[j].column_mask for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_number_column_names] = [
+      feed_examples[j].column_ids for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_processed_word_column] = [
+      feed_examples[j].processed_word_columns for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_word_column_mask] = [
+      feed_examples[j].word_column_mask for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_word_column_names] = [
+      feed_examples[j].word_column_ids for j in range(batch_size)
+  ]
+  feed_dict[gr.batch_word_column_entry_mask] = [
+      feed_examples[j].word_column_entry_mask for j in range(batch_size)
+  ]
+  return feed_dict
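
Aside (not part of the commit): a small, self-contained illustration of what exact_match(question, table, number=False) above computes for a word table. The word-match branch is restated inline so the snippet runs without the module's TensorFlow and wiki_data dependencies; the question and table are toy data.

question = ["which", "year", "did", "san", "jose", "win"]
word_columns = [[["san", "jose"], ["oakland"]]]  # one word column with two cells

# A cell matches when its token sequence occurs contiguously in the question.
answer = [[0 for _ in col] for col in word_columns]
match, matched_indices = {}, []
for i, col in enumerate(word_columns):
    for j, cell in enumerate(col):
        for k in range(len(question) - len(cell) + 1):
            if question[k:k + len(cell)] == cell:
                answer[i][j], match[i] = 1.0, 1.0
                matched_indices.append((k, len(cell)))

print(answer)           # [[1.0, 0]]  per-cell match indicator
print(match)            # {0: 1.0}    columns containing at least one matching cell
print(matched_indices)  # [(3, 2)]    (start position in the question, cell length)

complete_wiki_processing uses matched_indices to replace the matched question tokens with the unk token, and appends the entry_match_token to the question whenever any cell matched.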

+ 678 - 0
model.py

@@ -0,0 +1,678 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Author: aneelakantan (Arvind Neelakantan)
+"""
+
+import numpy as np
+import tensorflow as tf
+import nn_utils
+
+
+class Graph():
+
+  def __init__(self, utility, batch_size, max_passes, mode="train"):
+    self.utility = utility
+    self.data_type = self.utility.tf_data_type[self.utility.FLAGS.data_type]
+    self.max_elements = self.utility.FLAGS.max_elements
+    max_elements = self.utility.FLAGS.max_elements
+    self.num_cols = self.utility.FLAGS.max_number_cols
+    self.num_word_cols = self.utility.FLAGS.max_word_cols
+    self.question_length = self.utility.FLAGS.question_length
+    self.batch_size = batch_size
+    self.max_passes = max_passes
+    self.mode = mode
+    self.embedding_dims = self.utility.FLAGS.embedding_dims
+    #input question and a mask
+    self.batch_question = tf.placeholder(tf.int32,
+                                         [batch_size, self.question_length])
+    self.batch_question_attention_mask = tf.placeholder(
+        self.data_type, [batch_size, self.question_length])
+    #ground truth scalar answer and lookup answer
+    self.batch_answer = tf.placeholder(self.data_type, [batch_size])
+    self.batch_print_answer = tf.placeholder(
+        self.data_type,
+        [batch_size, self.num_cols + self.num_word_cols, max_elements])
+    #number columns and its processed version
+    self.batch_number_column = tf.placeholder(
+        self.data_type, [batch_size, self.num_cols, max_elements
+                        ])  #columns with numeric entries
+    self.batch_processed_number_column = tf.placeholder(
+        self.data_type, [batch_size, self.num_cols, max_elements])
+    self.batch_processed_sorted_index_number_column = tf.placeholder(
+        tf.int32, [batch_size, self.num_cols, max_elements])
+    #word columns and its processed version
+    self.batch_processed_word_column = tf.placeholder(
+        self.data_type, [batch_size, self.num_word_cols, max_elements])
+    self.batch_processed_sorted_index_word_column = tf.placeholder(
+        tf.int32, [batch_size, self.num_word_cols, max_elements])
+    self.batch_word_column_entry_mask = tf.placeholder(
+        tf.int32, [batch_size, self.num_word_cols, max_elements])
+    #names of word and number columns along with their mask
+    self.batch_word_column_names = tf.placeholder(
+        tf.int32,
+        [batch_size, self.num_word_cols, self.utility.FLAGS.max_entry_length])
+    self.batch_word_column_mask = tf.placeholder(
+        self.data_type, [batch_size, self.num_word_cols])
+    self.batch_number_column_names = tf.placeholder(
+        tf.int32,
+        [batch_size, self.num_cols, self.utility.FLAGS.max_entry_length])
+    self.batch_number_column_mask = tf.placeholder(self.data_type,
+                                                   [batch_size, self.num_cols])
+    #exact match and group by max operation
+    self.batch_exact_match = tf.placeholder(
+        self.data_type,
+        [batch_size, self.num_cols + self.num_word_cols, max_elements])
+    self.batch_column_exact_match = tf.placeholder(
+        self.data_type, [batch_size, self.num_cols + self.num_word_cols])
+    self.batch_group_by_max = tf.placeholder(
+        self.data_type,
+        [batch_size, self.num_cols + self.num_word_cols, max_elements])
+    #numbers in the question along with their position. This is used to compute arguments to the comparison operations
+    self.batch_question_number = tf.placeholder(self.data_type, [batch_size, 1])
+    self.batch_question_number_one = tf.placeholder(self.data_type,
+                                                    [batch_size, 1])
+    self.batch_question_number_mask = tf.placeholder(
+        self.data_type, [batch_size, max_elements])
+    self.batch_question_number_one_mask = tf.placeholder(self.data_type,
+                                                         [batch_size, 1])
+    self.batch_ordinal_question = tf.placeholder(
+        self.data_type, [batch_size, self.question_length])
+    self.batch_ordinal_question_one = tf.placeholder(
+        self.data_type, [batch_size, self.question_length])
+
+  def LSTM_question_embedding(self, sentence, sentence_length):
+    #LSTM processes the input question
+    lstm_params = "question_lstm"
+    hidden_vectors = []
+    sentence = self.batch_question
+    question_hidden = tf.zeros(
+        [self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
+    question_c_hidden = tf.zeros(
+        [self.batch_size, self.utility.FLAGS.embedding_dims], self.data_type)
+    if (self.utility.FLAGS.rnn_dropout > 0.0):
+      if (self.mode == "train"):
+        rnn_dropout_mask = tf.cast(
+            tf.random_uniform(
+                tf.shape(question_hidden), minval=0.0, maxval=1.0) <
+            self.utility.FLAGS.rnn_dropout,
+            self.data_type) / self.utility.FLAGS.rnn_dropout
+      else:
+        rnn_dropout_mask = tf.ones_like(question_hidden)
+    for question_iterator in range(self.question_length):
+      curr_word = sentence[:, question_iterator]
+      question_vector = nn_utils.apply_dropout(
+          nn_utils.get_embedding(curr_word, self.utility, self.params),
+          self.utility.FLAGS.dropout, self.mode)
+      question_hidden, question_c_hidden = nn_utils.LSTMCell(
+          question_vector, question_hidden, question_c_hidden, lstm_params,
+          self.params)
+      if (self.utility.FLAGS.rnn_dropout > 0.0):
+        question_hidden = question_hidden * rnn_dropout_mask
+      hidden_vectors.append(tf.expand_dims(question_hidden, 0))
+    hidden_vectors = tf.concat(0, hidden_vectors)
+    return question_hidden, hidden_vectors
+
+  def history_recurrent_step(self, curr_hprev, hprev):
+    #A single RNN step for controller or history RNN
+    return tf.tanh(
+        tf.matmul(
+            tf.concat(1, [hprev, curr_hprev]), self.params[
+                "history_recurrent"])) + self.params["history_recurrent_bias"]
+
+  def question_number_softmax(self, hidden_vectors):
+    #Attention over the question to decide which question number is passed to the comparison ops
+    def compute_ans(op_embedding, comparison):
+      op_embedding = tf.expand_dims(op_embedding, 0)
+      #dot product of the operation embedding with the hidden state to the left of the number occurrence
+      first = tf.transpose(
+          tf.matmul(op_embedding,
+                    tf.transpose(
+                        tf.reduce_sum(hidden_vectors * tf.tile(
+                            tf.expand_dims(
+                                tf.transpose(self.batch_ordinal_question), 2),
+                            [1, 1, self.utility.FLAGS.embedding_dims]), 0))))
+      second = self.batch_question_number_one_mask + tf.transpose(
+          tf.matmul(op_embedding,
+                    tf.transpose(
+                        tf.reduce_sum(hidden_vectors * tf.tile(
+                            tf.expand_dims(
+                                tf.transpose(self.batch_ordinal_question_one), 2
+                            ), [1, 1, self.utility.FLAGS.embedding_dims]), 0))))
+      question_number_softmax = tf.nn.softmax(tf.concat(1, [first, second]))
+      if (self.mode == "test"):
+        cond = tf.equal(question_number_softmax,
+                        tf.reshape(
+                            tf.reduce_max(question_number_softmax, 1),
+                            [self.batch_size, 1]))
+        question_number_softmax = tf.select(
+            cond,
+            tf.fill(tf.shape(question_number_softmax), 1.0),
+            tf.fill(tf.shape(question_number_softmax), 0.0))
+        question_number_softmax = tf.cast(question_number_softmax,
+                                          self.data_type)
+      ans = tf.reshape(
+          tf.reduce_sum(question_number_softmax * tf.concat(
+              1, [self.batch_question_number, self.batch_question_number_one]),
+                        1), [self.batch_size, 1])
+      return ans
+
+    def compute_op_position(op_name):
+      for i in range(len(self.utility.operations_set)):
+        if (op_name == self.utility.operations_set[i]):
+          return i
+
+    def compute_question_number(op_name):
+      op_embedding = tf.nn.embedding_lookup(self.params_unit,
+                                            compute_op_position(op_name))
+      return compute_ans(op_embedding, op_name)
+
+    curr_greater_question_number = compute_question_number("greater")
+    curr_lesser_question_number = compute_question_number("lesser")
+    curr_geq_question_number = compute_question_number("geq")
+    curr_leq_question_number = compute_question_number("leq")
+    return curr_greater_question_number, curr_lesser_question_number, curr_geq_question_number, curr_leq_question_number
+
+  def perform_attention(self, context_vector, hidden_vectors, length, mask):
+    #Performs attention on hidden_vectors using the context vector
+    context_vector = tf.tile(
+        tf.expand_dims(context_vector, 0), [length, 1, 1])  #time * bs * d
+    attention_softmax = tf.nn.softmax(
+        tf.transpose(tf.reduce_sum(context_vector * hidden_vectors, 2)) +
+        mask)  #batch_size * time
+    attention_softmax = tf.tile(
+        tf.expand_dims(tf.transpose(attention_softmax), 2),
+        [1, 1, self.embedding_dims])
+    ans_vector = tf.reduce_sum(attention_softmax * hidden_vectors, 0)
+    return ans_vector
+
+  #computes embeddings for column names using parameters of question module
+  def get_column_hidden_vectors(self):
+    #vector representations for the column names
+    self.column_hidden_vectors = tf.reduce_sum(
+        nn_utils.get_embedding(self.batch_number_column_names, self.utility,
+                               self.params), 2)
+    self.word_column_hidden_vectors = tf.reduce_sum(
+        nn_utils.get_embedding(self.batch_word_column_names, self.utility,
+                               self.params), 2)
+
+  def create_summary_embeddings(self):
+    #embeddings for each text entry in the table using parameters of the question module
+    self.summary_text_entry_embeddings = tf.reduce_sum(
+        tf.expand_dims(self.batch_exact_match, 3) * tf.expand_dims(
+            tf.expand_dims(
+                tf.expand_dims(
+                    nn_utils.get_embedding(self.utility.entry_match_token_id,
+                                           self.utility, self.params), 0), 1),
+            2), 2)
+
+  def compute_column_softmax(self, column_controller_vector, time_step):
+    #compute softmax over all the columns using column controller vector
+    column_controller_vector = tf.tile(
+        tf.expand_dims(column_controller_vector, 1),
+        [1, self.num_cols + self.num_word_cols, 1])  #max_cols * bs * d
+    column_controller_vector = nn_utils.apply_dropout(
+        column_controller_vector, self.utility.FLAGS.dropout, self.mode)
+    self.full_column_hidden_vectors = tf.concat(
+        1, [self.column_hidden_vectors, self.word_column_hidden_vectors])
+    self.full_column_hidden_vectors += self.summary_text_entry_embeddings
+    self.full_column_hidden_vectors = nn_utils.apply_dropout(
+        self.full_column_hidden_vectors, self.utility.FLAGS.dropout, self.mode)
+    column_logits = tf.reduce_sum(
+        column_controller_vector * self.full_column_hidden_vectors, 2) + (
+            self.params["word_match_feature_column_name"] *
+            self.batch_column_exact_match) + self.full_column_mask
+    column_softmax = tf.nn.softmax(column_logits)  #batch_size * max_cols
+    return column_softmax
+
+  def compute_first_or_last(self, select, first=True):
+    #perform the first or last operation on the row selector with probabilistic row selection
+    answer = tf.zeros_like(select)
+    running_sum = tf.zeros([self.batch_size, 1], self.data_type)
+    for i in range(self.max_elements):
+      if (first):
+        current = tf.slice(select, [0, i], [self.batch_size, 1])
+      else:
+        current = tf.slice(select, [0, self.max_elements - 1 - i],
+                           [self.batch_size, 1])
+      curr_prob = current * (1 - running_sum)
+      curr_prob = curr_prob * tf.cast(curr_prob >= 0.0, self.data_type)
+      running_sum += curr_prob
+      temp_ans = []
+      curr_prob = tf.expand_dims(tf.reshape(curr_prob, [self.batch_size]), 0)
+      for i_ans in range(self.max_elements):
+        if (not (first) and i_ans == self.max_elements - 1 - i):
+          temp_ans.append(curr_prob)
+        elif (first and i_ans == i):
+          temp_ans.append(curr_prob)
+        else:
+          temp_ans.append(tf.zeros_like(curr_prob))
+      temp_ans = tf.transpose(tf.concat(0, temp_ans))
+      answer += temp_ans
+    return answer
+
+  def make_hard_softmax(self, softmax):
+    #converts soft selection to hard selection. used at test time
+    cond = tf.equal(
+        softmax, tf.reshape(tf.reduce_max(softmax, 1), [self.batch_size, 1]))
+    softmax = tf.select(
+        cond, tf.fill(tf.shape(softmax), 1.0), tf.fill(tf.shape(softmax), 0.0))
+    softmax = tf.cast(softmax, self.data_type)
+    return softmax
+
+  def compute_max_or_min(self, select, maxi=True):
+    #computes the argmax and argmin of a column with probabilistic row selection
+    answer = tf.zeros([
+        self.batch_size, self.num_cols + self.num_word_cols, self.max_elements
+    ], self.data_type)
+    sum_prob = tf.zeros([self.batch_size, self.num_cols + self.num_word_cols],
+                        self.data_type)
+    for j in range(self.max_elements):
+      if (maxi):
+        curr_pos = j
+      else:
+        curr_pos = self.max_elements - 1 - j
+      select_index = tf.slice(self.full_processed_sorted_index_column,
+                              [0, 0, curr_pos], [self.batch_size, -1, 1])
+      select_mask = tf.equal(
+          tf.tile(
+              tf.expand_dims(
+                  tf.tile(
+                      tf.expand_dims(tf.range(self.max_elements), 0),
+                      [self.batch_size, 1]), 1),
+              [1, self.num_cols + self.num_word_cols, 1]), select_index)
+      curr_prob = tf.expand_dims(select, 1) * tf.cast(
+          select_mask, self.data_type) * self.select_bad_number_mask
+      curr_prob = curr_prob * tf.expand_dims((1 - sum_prob), 2)
+      curr_prob = curr_prob * tf.expand_dims(
+          tf.cast((1 - sum_prob) > 0.0, self.data_type), 2)
+      answer = tf.select(select_mask, curr_prob, answer)
+      sum_prob += tf.reduce_sum(curr_prob, 2)
+    return answer
+
+  def perform_operations(self, softmax, full_column_softmax, select,
+                         prev_select_1, curr_pass):
+    #performs all 15 operations; computes the scalar output, lookup answer and row selector
+    column_softmax = tf.slice(full_column_softmax, [0, 0],
+                              [self.batch_size, self.num_cols])
+    word_column_softmax = tf.slice(full_column_softmax, [0, self.num_cols],
+                                   [self.batch_size, self.num_word_cols])
+    init_max = self.compute_max_or_min(select, maxi=True)
+    init_min = self.compute_max_or_min(select, maxi=False)
+    #operations that are column independent
+    count = tf.reshape(tf.reduce_sum(select, 1), [self.batch_size, 1])
+    select_full_column_softmax = tf.tile(
+        tf.expand_dims(full_column_softmax, 2),
+        [1, 1, self.max_elements
+        ])  #BS * (max_cols + max_word_cols) * max_elements
+    select_word_column_softmax = tf.tile(
+        tf.expand_dims(word_column_softmax, 2),
+        [1, 1, self.max_elements])  #BS * max_word_cols * max_elements
+    select_greater = tf.reduce_sum(
+        self.init_select_greater * select_full_column_softmax,
+        1) * self.batch_question_number_mask  #BS * max_elements
+    select_lesser = tf.reduce_sum(
+        self.init_select_lesser * select_full_column_softmax,
+        1) * self.batch_question_number_mask  #BS * max_elements
+    select_geq = tf.reduce_sum(
+        self.init_select_geq * select_full_column_softmax,
+        1) * self.batch_question_number_mask  #BS * max_elements
+    select_leq = tf.reduce_sum(
+        self.init_select_leq * select_full_column_softmax,
+        1) * self.batch_question_number_mask  #BS * max_elements
+    select_max = tf.reduce_sum(init_max * select_full_column_softmax,
+                               1)  #BS * max_elements
+    select_min = tf.reduce_sum(init_min * select_full_column_softmax,
+                               1)  #BS * max_elements
+    select_prev = tf.concat(1, [
+        tf.slice(select, [0, 1], [self.batch_size, self.max_elements - 1]),
+        tf.cast(tf.zeros([self.batch_size, 1]), self.data_type)
+    ])
+    select_next = tf.concat(1, [
+        tf.cast(tf.zeros([self.batch_size, 1]), self.data_type), tf.slice(
+            select, [0, 0], [self.batch_size, self.max_elements - 1])
+    ])
+    select_last_rs = self.compute_first_or_last(select, False)
+    select_first_rs = self.compute_first_or_last(select, True)
+    select_word_match = tf.reduce_sum(self.batch_exact_match *
+                                      select_full_column_softmax, 1)
+    select_group_by_max = tf.reduce_sum(self.batch_group_by_max *
+                                        select_full_column_softmax, 1)
+    length_content = 1
+    length_select = 13
+    length_print = 1
+    values = tf.concat(1, [count])
+    softmax_content = tf.slice(softmax, [0, 0],
+                               [self.batch_size, length_content])
+    #compute scalar output
+    output = tf.reduce_sum(tf.mul(softmax_content, values), 1)
+    #compute lookup answer
+    softmax_print = tf.slice(softmax, [0, length_content + length_select],
+                             [self.batch_size, length_print])
+    curr_print = select_full_column_softmax * tf.tile(
+        tf.expand_dims(select, 1),
+        [1, self.num_cols + self.num_word_cols, 1
+        ])  #BS * max_cols * max_elements (considers only the column)
+    self.batch_lookup_answer = curr_print * tf.tile(
+        tf.expand_dims(softmax_print, 2),
+        [1, self.num_cols + self.num_word_cols, self.max_elements
+        ])  #BS * max_cols * max_elements
+    self.batch_lookup_answer = self.batch_lookup_answer * self.select_full_mask
+    #compute row select
+    softmax_select = tf.slice(softmax, [0, length_content],
+                              [self.batch_size, length_select])
+    select_lists = [
+        tf.expand_dims(select_prev, 1), tf.expand_dims(select_next, 1),
+        tf.expand_dims(select_first_rs, 1), tf.expand_dims(select_last_rs, 1),
+        tf.expand_dims(select_group_by_max, 1),
+        tf.expand_dims(select_greater, 1), tf.expand_dims(select_lesser, 1),
+        tf.expand_dims(select_geq, 1), tf.expand_dims(select_leq, 1),
+        tf.expand_dims(select_max, 1), tf.expand_dims(select_min, 1),
+        tf.expand_dims(select_word_match, 1),
+        tf.expand_dims(self.reset_select, 1)
+    ]
+    select = tf.reduce_sum(
+        tf.tile(tf.expand_dims(softmax_select, 2), [1, 1, self.max_elements]) *
+        tf.concat(1, select_lists), 1)
+    select = select * self.select_whole_mask
+    return output, select
+
+  def one_pass(self, select, question_embedding, hidden_vectors, hprev,
+               prev_select_1, curr_pass):
+    #Performs one timestep which involves selecting an operation and a column
+    attention_vector = self.perform_attention(
+        hprev, hidden_vectors, self.question_length,
+        self.batch_question_attention_mask)  #batch_size * embedding_dims
+    controller_vector = tf.nn.relu(
+        tf.matmul(hprev, self.params["controller_prev"]) + tf.matmul(
+            tf.concat(1, [question_embedding, attention_vector]), self.params[
+                "controller"]))
+    column_controller_vector = tf.nn.relu(
+        tf.matmul(hprev, self.params["column_controller_prev"]) + tf.matmul(
+            tf.concat(1, [question_embedding, attention_vector]), self.params[
+                "column_controller"]))
+    controller_vector = nn_utils.apply_dropout(
+        controller_vector, self.utility.FLAGS.dropout, self.mode)
+    self.operation_logits = tf.matmul(controller_vector,
+                                      tf.transpose(self.params_unit))
+    softmax = tf.nn.softmax(self.operation_logits)
+    soft_softmax = softmax
+    #compute column softmax: bs * max_columns
+    weighted_op_representation = tf.transpose(
+        tf.matmul(tf.transpose(self.params_unit), tf.transpose(softmax)))
+    column_controller_vector = tf.nn.relu(
+        tf.matmul(
+            tf.concat(1, [
+                column_controller_vector, weighted_op_representation
+            ]), self.params["break_conditional"]))
+    full_column_softmax = self.compute_column_softmax(column_controller_vector,
+                                                      curr_pass)
+    soft_column_softmax = full_column_softmax
+    if (self.mode == "test"):
+      full_column_softmax = self.make_hard_softmax(full_column_softmax)
+      softmax = self.make_hard_softmax(softmax)
+    output, select = self.perform_operations(softmax, full_column_softmax,
+                                             select, prev_select_1, curr_pass)
+    return output, select, softmax, soft_softmax, full_column_softmax, soft_column_softmax
+
+  def compute_lookup_error(self, val):
+    #computes lookup error.
+    cond = tf.equal(self.batch_print_answer, val)
+    inter = tf.select(
+        cond, self.init_print_error,
+        tf.tile(
+            tf.reshape(tf.constant(1e10, self.data_type), [1, 1, 1]), [
+                self.batch_size, self.utility.FLAGS.max_word_cols +
+                self.utility.FLAGS.max_number_cols,
+                self.utility.FLAGS.max_elements
+            ]))
+    return tf.reduce_min(tf.reduce_min(inter, 1), 1) * tf.cast(
+        tf.greater(
+            tf.reduce_sum(tf.reduce_sum(tf.cast(cond, self.data_type), 1), 1),
+            0.0), self.data_type)
+
+  def soft_min(self, x, y):
+    return tf.maximum(-1.0 * (1 / (
+        self.utility.FLAGS.soft_min_value + 0.0)) * tf.log(
+            tf.exp(-self.utility.FLAGS.soft_min_value * x) + tf.exp(
+                -self.utility.FLAGS.soft_min_value * y)), tf.zeros_like(x))
+
+  def error_computation(self):
+    #computes the error of each example in a batch
+    math_error = 0.5 * tf.square(tf.sub(self.scalar_output, self.batch_answer))
+    #scale math error
+    math_error = math_error / self.rows
+    math_error = tf.minimum(math_error, self.utility.FLAGS.max_math_error *
+                            tf.ones(tf.shape(math_error), self.data_type))
+    self.init_print_error = tf.select(
+        self.batch_gold_select, -1 * tf.log(self.batch_lookup_answer + 1e-300 +
+                                            self.invert_select_full_mask), -1 *
+        tf.log(1 - self.batch_lookup_answer)) * self.select_full_mask
+    print_error_1 = self.init_print_error * tf.cast(
+        tf.equal(self.batch_print_answer, 0.0), self.data_type)
+    print_error = tf.reduce_sum(tf.reduce_sum((print_error_1), 1), 1)
+    for val in range(1, 58):
+      print_error += self.compute_lookup_error(val + 0.0)
+    print_error = print_error * self.utility.FLAGS.print_cost / self.num_entries
+    if (self.mode == "train"):
+      error = tf.select(
+          tf.logical_and(
+              tf.not_equal(self.batch_answer, 0.0),
+              tf.not_equal(
+                  tf.reduce_sum(tf.reduce_sum(self.batch_print_answer, 1), 1),
+                  0.0)),
+          self.soft_min(math_error, print_error),
+          tf.select(
+              tf.not_equal(self.batch_answer, 0.0), math_error, print_error))
+    else:
+      error = tf.select(
+          tf.logical_and(
+              tf.equal(self.scalar_output, 0.0),
+              tf.equal(
+                  tf.reduce_sum(tf.reduce_sum(self.batch_lookup_answer, 1), 1),
+                  0.0)),
+          tf.ones_like(math_error),
+          tf.select(
+              tf.equal(self.scalar_output, 0.0), print_error, math_error))
+    return error
+
+  def batch_process(self):
+    #Computes loss and fraction of correct examples in a batch.
+    self.params_unit = nn_utils.apply_dropout(
+        self.params["unit"], self.utility.FLAGS.dropout, self.mode)
+    batch_size = self.batch_size
+    max_passes = self.max_passes
+    num_timesteps = 1
+    max_elements = self.max_elements
+    select = tf.cast(
+        tf.fill([self.batch_size, max_elements], 1.0), self.data_type)
+    hprev = tf.cast(
+        tf.fill([self.batch_size, self.embedding_dims], 0.0),
+        self.data_type)  #running sum of the hidden states of the model
+    output = tf.cast(tf.fill([self.batch_size, 1], 0.0),
+                     self.data_type)  #output of the model
+    correct = tf.cast(
+        tf.fill([1], 0.0), self.data_type
+    )  #to compute accuracy, returns number of correct examples for this batch
+    total_error = 0.0
+    prev_select_1 = tf.zeros_like(select)
+    self.create_summary_embeddings()
+    self.get_column_hidden_vectors()
+    #get question embedding
+    question_embedding, hidden_vectors = self.LSTM_question_embedding(
+        self.batch_question, self.question_length)
+    #compute arguments for comparison operation
+    greater_question_number, lesser_question_number, geq_question_number, leq_question_number = self.question_number_softmax(
+        hidden_vectors)
+    self.init_select_greater = tf.cast(
+        tf.greater(self.full_processed_column,
+                   tf.expand_dims(greater_question_number, 2)), self.
+        data_type) * self.select_bad_number_mask  #bs * max_cols * max_elements
+    self.init_select_lesser = tf.cast(
+        tf.less(self.full_processed_column,
+                tf.expand_dims(lesser_question_number, 2)), self.
+        data_type) * self.select_bad_number_mask  #bs * max_cols * max_elements
+    self.init_select_geq = tf.cast(
+        tf.greater_equal(self.full_processed_column,
+                         tf.expand_dims(geq_question_number, 2)), self.
+        data_type) * self.select_bad_number_mask  #bs * max_cols * max_elements
+    self.init_select_leq = tf.cast(
+        tf.less_equal(self.full_processed_column,
+                      tf.expand_dims(leq_question_number, 2)), self.
+        data_type) * self.select_bad_number_mask  #bs * max_cols * max_elements
+    self.init_select_word_match = 0
+    if (self.utility.FLAGS.rnn_dropout > 0.0):
+      if (self.mode == "train"):
+        history_rnn_dropout_mask = tf.cast(
+            tf.random_uniform(
+                tf.shape(hprev), minval=0.0, maxval=1.0) <
+            self.utility.FLAGS.rnn_dropout,
+            self.data_type) / self.utility.FLAGS.rnn_dropout
+      else:
+        history_rnn_dropout_mask = tf.ones_like(hprev)
+    select = select * self.select_whole_mask
+    self.batch_log_prob = tf.zeros([self.batch_size], dtype=self.data_type)
+    #Perform max_passes passes; at each pass select an operation and a column
+    for curr_pass in range(max_passes):
+      print "step: ", curr_pass
+      output, select, softmax, soft_softmax, column_softmax, soft_column_softmax = self.one_pass(
+          select, question_embedding, hidden_vectors, hprev, prev_select_1,
+          curr_pass)
+      prev_select_1 = select
+      #compute input to history RNN
+      input_op = tf.transpose(
+          tf.matmul(
+              tf.transpose(self.params_unit), tf.transpose(
+                  soft_softmax)))  #weighted average of the embeddings of the operations
+      input_col = tf.reduce_sum(
+          tf.expand_dims(soft_column_softmax, 2) *
+          self.full_column_hidden_vectors, 1)
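+      #weighted average of the column name representations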
+      history_input = tf.concat(1, [input_op, input_col])
+      history_input = nn_utils.apply_dropout(
+          history_input, self.utility.FLAGS.dropout, self.mode)
+      hprev = self.history_recurrent_step(history_input, hprev)
+      if (self.utility.FLAGS.rnn_dropout > 0.0):
+        hprev = hprev * history_rnn_dropout_mask
+    self.scalar_output = output
+    error = self.error_computation()
+    cond = tf.less(error, 0.0001, name="cond")
+    correct_add = tf.select(
+        cond, tf.fill(tf.shape(cond), 1.0), tf.fill(tf.shape(cond), 0.0))
+    correct = tf.reduce_sum(correct_add)
+    error = error / batch_size
+    total_error = tf.reduce_sum(error)
+    total_correct = correct / batch_size
+    return total_error, total_correct
+
+  def compute_error(self):
+    #Sets mask variables and performs batch processing
+    self.batch_gold_select = self.batch_print_answer > 0.0
+    self.full_column_mask = tf.concat(
+        1, [self.batch_number_column_mask, self.batch_word_column_mask])
+    self.full_processed_column = tf.concat(
+        1,
+        [self.batch_processed_number_column, self.batch_processed_word_column])
+    self.full_processed_sorted_index_column = tf.concat(1, [
+        self.batch_processed_sorted_index_number_column,
+        self.batch_processed_sorted_index_word_column
+    ])
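+    # The masks below separate real table entries from padding: pad_int marks
+    # padded rows in number columns, bad_number_pre_process marks entries that
+    # could not be parsed as numbers, and dummy_token_id marks padded word
+    # entries.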
+    self.select_bad_number_mask = tf.cast(
+        tf.logical_and(
+            tf.not_equal(self.full_processed_column,
+                         self.utility.FLAGS.pad_int),
+            tf.not_equal(self.full_processed_column,
+                         self.utility.FLAGS.bad_number_pre_process)),
+        self.data_type)
+    self.select_mask = tf.cast(
+        tf.logical_not(
+            tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int)),
+        self.data_type)
+    self.select_word_mask = tf.cast(
+        tf.logical_not(
+            tf.equal(self.batch_word_column_entry_mask,
+                     self.utility.dummy_token_id)), self.data_type)
+    self.select_full_mask = tf.concat(
+        1, [self.select_mask, self.select_word_mask])
+    self.select_whole_mask = tf.maximum(
+        tf.reshape(
+            tf.slice(self.select_mask, [0, 0, 0],
+                     [self.batch_size, 1, self.max_elements]),
+            [self.batch_size, self.max_elements]),
+        tf.reshape(
+            tf.slice(self.select_word_mask, [0, 0, 0],
+                     [self.batch_size, 1, self.max_elements]),
+            [self.batch_size, self.max_elements]))
+    self.invert_select_full_mask = tf.cast(
+        tf.concat(1, [
+            tf.equal(self.batch_number_column, self.utility.FLAGS.pad_int),
+            tf.equal(self.batch_word_column_entry_mask,
+                     self.utility.dummy_token_id)
+        ]), self.data_type)
+    self.batch_lookup_answer = tf.zeros(tf.shape(self.batch_gold_select))
+    self.reset_select = self.select_whole_mask
+    self.rows = tf.reduce_sum(self.select_whole_mask, 1)
+    self.num_entries = tf.reshape(
+        tf.reduce_sum(tf.reduce_sum(self.select_full_mask, 1), 1),
+        [self.batch_size])
+    self.final_error, self.final_correct = self.batch_process()
+    return self.final_error
+
+  def create_graph(self, params, global_step):
+    #Creates the graph that computes the error, computes gradients and updates the parameters
+    self.params = params
+    batch_size = self.batch_size
+    learning_rate = tf.cast(self.utility.FLAGS.learning_rate, self.data_type)
+    self.total_cost = self.compute_error() 
+    optimize_params = self.params.values()
+    optimize_names = self.params.keys()
+    print "optimize params ", optimize_names
+    if (self.utility.FLAGS.l2_regularizer > 0.0):
+      reg_cost = 0.0
+      for ind_param in self.params.keys():
+        reg_cost += tf.nn.l2_loss(self.params[ind_param])
+      self.total_cost += self.utility.FLAGS.l2_regularizer * reg_cost
+    grads = tf.gradients(self.total_cost, optimize_params, name="gradients")
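+    # Clip gradients by global norm: accumulate the squared norm over all
+    # gradients (IndexedSlices come from embedding lookups) and scale every
+    # gradient by min(1, clip_gradients / global_norm).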
+    grad_norm = 0.0
+    for p, name in zip(grads, optimize_names):
+      print "grads: ", p, name
+      if isinstance(p, tf.IndexedSlices):
+        grad_norm += tf.reduce_sum(p.values * p.values)
+      elif p is not None:
+        grad_norm += tf.reduce_sum(p * p)
+    grad_norm = tf.sqrt(grad_norm)
+    max_grad_norm = np.float32(self.utility.FLAGS.clip_gradients).astype(
+        self.utility.np_data_type[self.utility.FLAGS.data_type])
+    grad_scale = tf.minimum(
+        tf.cast(1.0, self.data_type), max_grad_norm / grad_norm)
+    clipped_grads = list()
+    for p in grads:
+      if isinstance(p, tf.IndexedSlices):
+        tmp = p.values * grad_scale
+        clipped_grads.append(tf.IndexedSlices(tmp, p.indices))
+      elif p is not None:
+        clipped_grads.append(p * grad_scale)
+      else:
+        clipped_grads.append(p)
+    grads = clipped_grads
+    self.global_step = global_step
+    params_list = self.params.values()
+    params_list.append(self.global_step)
+    adam = tf.train.AdamOptimizer(
+        learning_rate,
+        epsilon=tf.cast(self.utility.FLAGS.eps, self.data_type),
+        use_locking=True)
+    self.step = adam.apply_gradients(
+        zip(grads, optimize_params), global_step=self.global_step)
+    self.init_op = tf.initialize_all_variables()
+

+ 234 - 0
neural_programmer.py

@@ -0,0 +1,234 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of the Neural Programmer model described in https://openreview.net/pdf?id=ry2YOrcge
+
+This file loads and pre-processes the data, constructs the TF graph, and
+performs training or evaluation as specified by the evaluator_job flag.
+Author: aneelakantan (Arvind Neelakantan)
+"""
+import time
+from random import Random
+import numpy as np
+import tensorflow as tf
+import model
+import wiki_data
+import parameters
+import data_utils
+
+tf.flags.DEFINE_integer("train_steps", 100001, "Number of steps to train")
+tf.flags.DEFINE_integer("eval_cycle", 500,
+                        "Evaluate model at every eval_cycle steps")
+tf.flags.DEFINE_integer("max_elements", 100,
+                        "maximum rows that are  considered for processing")
+tf.flags.DEFINE_integer(
+    "max_number_cols", 15,
+    "maximum number columns that are considered for processing")
+tf.flags.DEFINE_integer(
+    "max_word_cols", 25,
+    "maximum number columns that are considered for processing")
+tf.flags.DEFINE_integer("question_length", 62, "maximum question length")
+tf.flags.DEFINE_integer("max_entry_length", 1, "")
+tf.flags.DEFINE_integer("max_passes", 4, "number of operation passes")
+tf.flags.DEFINE_integer("embedding_dims", 256, "")
+tf.flags.DEFINE_integer("batch_size", 20, "")
+tf.flags.DEFINE_float("clip_gradients", 1.0, "")
+tf.flags.DEFINE_float("eps", 1e-6, "")
+tf.flags.DEFINE_float("param_init", 0.1, "")
+tf.flags.DEFINE_float("learning_rate", 0.001, "")
+tf.flags.DEFINE_float("l2_regularizer", 0.0001, "")
+tf.flags.DEFINE_float("print_cost", 50.0,
+                      "weighting factor in the objective function")
+tf.flags.DEFINE_string("job_id", "temp", """job id""")
+tf.flags.DEFINE_string("output_dir", "../model/",
+                       """output_dir""")
+tf.flags.DEFINE_string("data_dir", "../data/",
+                       """data_dir""")
+tf.flags.DEFINE_integer("write_every", 500, "wrtie every N")
+tf.flags.DEFINE_integer("param_seed", 150, "")
+tf.flags.DEFINE_integer("python_seed", 200, "")
+tf.flags.DEFINE_float("dropout", 0.8, "dropout keep probability")
+tf.flags.DEFINE_float("rnn_dropout", 0.9,
+                      "dropout keep probability for rnn connections")
+tf.flags.DEFINE_float("pad_int", -20000.0,
+                      "number columns are padded with pad_int")
+tf.flags.DEFINE_string("data_type", "double", "float or double")
+tf.flags.DEFINE_float("word_dropout_prob", 0.9, "word dropout keep prob")
+tf.flags.DEFINE_integer("word_cutoff", 10, "")
+tf.flags.DEFINE_integer("vocab_size", 10800, "")
+tf.flags.DEFINE_boolean("evaluator_job", False,
+                        "wehther to run as trainer/evaluator")
+tf.flags.DEFINE_float(
+    "bad_number_pre_process", -200000.0,
+    "number that is added to a corrupted table entry in a number column")
+tf.flags.DEFINE_float("max_math_error", 3.0,
+                      "max square loss error that is considered")
+tf.flags.DEFINE_float("soft_min_value", 5.0, "")
+FLAGS = tf.flags.FLAGS
+
+
+class Utility:
+  #holds FLAGS and other variables that are used in different files
+  def __init__(self):
+    global FLAGS
+    self.FLAGS = FLAGS
+    self.unk_token = "UNK"
+    self.entry_match_token = "entry_match"
+    self.column_match_token = "column_match"
+    self.dummy_token = "dummy_token"
+    self.tf_data_type = {}
+    self.tf_data_type["double"] = tf.float64
+    self.tf_data_type["float"] = tf.float32
+    self.np_data_type = {}
+    self.np_data_type["double"] = np.float64
+    self.np_data_type["float"] = np.float32
+    self.operations_set = ["count"] + [
+        "prev", "next", "first_rs", "last_rs", "group_by_max", "greater",
+        "lesser", "geq", "leq", "max", "min", "word-match"
+    ] + ["reset_select"] + ["print"]
+    self.word_ids = {}
+    self.reverse_word_ids = {}
+    self.word_count = {}
+    self.random = Random(FLAGS.python_seed)
+
+
+def evaluate(sess, data, batch_size, graph, i):
+  #computes accuracy
+  num_examples = 0.0
+  gc = 0.0
+  for j in range(0, len(data) - batch_size + 1, batch_size):
+    [ct] = sess.run([graph.final_correct],
+                    feed_dict=data_utils.generate_feed_dict(data, j, batch_size,
+                                                            graph))
+    gc += ct * batch_size
+    num_examples += batch_size
+  print "dev set accuracy   after ", i, " : ", gc / num_examples
+  print num_examples, len(data)
+  print "--------"
+
+
+def Train(graph, utility, batch_size, train_data, sess, model_dir,
+          saver):
+  #performs training
+  curr = 0
+  train_set_loss = 0.0
+  utility.random.shuffle(train_data)
+  start = time.time()
+  for i in range(utility.FLAGS.train_steps):
+    curr_step = i
+    if (i > 0 and i % FLAGS.write_every == 0):
+      model_file = model_dir + "/model_" + str(i)
+      saver.save(sess, model_file)
+    if curr + batch_size >= len(train_data):
+      curr = 0
+      utility.random.shuffle(train_data)
+    step, cost_value = sess.run(
+        [graph.step, graph.total_cost],
+        feed_dict=data_utils.generate_feed_dict(
+            train_data, curr, batch_size, graph, train=True, utility=utility))
+    curr = curr + batch_size
+    train_set_loss += cost_value
+    if (i > 0 and i % FLAGS.eval_cycle == 0):
+      end = time.time()
+      time_taken = end - start
+      print "step ", i, " ", time_taken, " seconds "
+      start = end
+      print " printing train set loss: ", train_set_loss / utility.FLAGS.eval_cycle
+      train_set_loss = 0.0
+
+
+def master(train_data, dev_data, utility):
+  #creates TF graph and calls trainer or evaluator
+  batch_size = utility.FLAGS.batch_size 
+  model_dir = utility.FLAGS.output_dir + "/model" + utility.FLAGS.job_id + "/"
+  #create all parameters of the model
+  param_class = parameters.Parameters(utility)
+  params, global_step, init = param_class.parameters(utility)
+  key = "test" if (FLAGS.evaluator_job) else "train"
+  graph = model.Graph(utility, batch_size, utility.FLAGS.max_passes, mode=key)
+  graph.create_graph(params, global_step)
+  prev_dev_error = 0.0
+  final_loss = 0.0
+  final_accuracy = 0.0
+  #start session
+  with tf.Session() as sess:
+    sess.run(init.name)
+    sess.run(graph.init_op.name)
+    to_save = params.copy()
+    saver = tf.train.Saver(to_save, max_to_keep=500)
+    if (FLAGS.evaluator_job):
+      while True:
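+        # List the checkpoints written by the trainer, drop the most recent one
+        # (it may still be in the middle of being written), then restore and
+        # score each remaining checkpoint on the dev set.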
+        selected_models = {}
+        file_list = tf.gfile.ListDirectory(model_dir)
+        for model_file in file_list:
+          if ("checkpoint" in model_file or "index" in model_file or
+              "meta" in model_file):
+            continue
+          if ("data" in model_file):
+            model_file = model_file.split(".")[0]
+          model_step = int(
+              model_file.split("_")[len(model_file.split("_")) - 1])
+          selected_models[model_step] = model_file
+        file_list = sorted(selected_models.items(), key=lambda x: x[0])
+        if (len(file_list) > 0):
+          file_list = file_list[0:len(file_list) - 1]
+        print "list of models: ", file_list
+        for model_file in file_list:
+          model_file = model_file[1]
+          print "restoring: ", model_file
+          saver.restore(sess, model_dir + "/" + model_file)
+          model_step = int(
+              model_file.split("_")[len(model_file.split("_")) - 1])
+          print "evaluating on dev ", model_file, model_step
+          evaluate(sess, dev_data, batch_size, graph, model_step)
+    else:
+      ckpt = tf.train.get_checkpoint_state(model_dir)
+      print "model dir: ", model_dir
+      if (not (tf.gfile.IsDirectory(model_dir))):
+        print "create dir: ", model_dir
+        tf.gfile.MkDir(model_dir)
+      Train(graph, utility, batch_size, train_data, sess, model_dir,
+            saver)
+
+def main(args):
+  utility = Utility()
+  train_name = "random-split-1-train.examples"
+  dev_name = "random-split-1-dev.examples"
+  test_name = "pristine-unseen-tables.examples"
+  #load data
+  dat = wiki_data.WikiQuestionGenerator(train_name, dev_name, test_name, FLAGS.data_dir)
+  train_data, dev_data, test_data = dat.load()
+  utility.words = []
+  utility.word_ids = {}
+  utility.reverse_word_ids = {}
+  #construct vocabulary
+  data_utils.construct_vocab(train_data, utility)
+  data_utils.construct_vocab(dev_data, utility, True)
+  data_utils.construct_vocab(test_data, utility, True)
+  data_utils.add_special_words(utility)
+  data_utils.perform_word_cutoff(utility)
+  #convert data to int format and pad the inputs
+  train_data = data_utils.complete_wiki_processing(train_data, utility, True)
+  dev_data = data_utils.complete_wiki_processing(dev_data, utility, False)
+  test_data = data_utils.complete_wiki_processing(test_data, utility, False)
+  print "# train examples ", len(train_data)
+  print "# dev examples ", len(dev_data)
+  print "# test examples ", len(test_data)
+  print "running open source"
+  #construct TF graph and train or evaluate
+  master(train_data, dev_data, utility)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 68 - 0
nn_utils.py

@@ -0,0 +1,68 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Author: aneelakantan (Arvind Neelakantan)
+"""
+
+import tensorflow as tf
+
+def get_embedding(word, utility, params):
+  return tf.nn.embedding_lookup(params["word"], word)
+
+
+def apply_dropout(x, dropout_rate, mode):
+  #dropout_rate is a keep probability; dropout is only applied at train time
+  if (dropout_rate > 0.0 and mode == "train"):
+    x = tf.nn.dropout(x, dropout_rate)
+  return x
+
+
+def LSTMCell(x, mprev, cprev, key, params):
+  """Create an LSTM cell.
+
+  Implements the equations in pg.2 from
+  "Long Short-Term Memory Based Recurrent Neural Network Architectures
+  For Large Vocabulary Speech Recognition",
+  Hasim Sak, Andrew Senior, Francoise Beaufays.
+
+  Args:
+    x: Inputs to this cell.
+    mprev: m_{t-1}, the recurrent activations (same as the output)
+      from the previous cell.
+    cprev: c_{t-1}, the cell activations from the previous cell.
+    key: Prefix used to look up this cell's weights and biases in params.
+    params: Dictionary holding the model parameters.
+
+  Returns:
+    m: Outputs of this cell.
+    c: Cell activations.
+  """
+
+  i = tf.matmul(x, params[key + "_ix"]) + tf.matmul(mprev, params[key + "_im"])
+  i = tf.nn.bias_add(i, params[key + "_i"])
+  f = tf.matmul(x, params[key + "_fx"]) + tf.matmul(mprev, params[key + "_fm"])
+  f = tf.nn.bias_add(f, params[key + "_f"])
+  c = tf.matmul(x, params[key + "_cx"]) + tf.matmul(mprev, params[key + "_cm"])
+  c = tf.nn.bias_add(c, params[key + "_c"])
+  o = tf.matmul(x, params[key + "_ox"]) + tf.matmul(mprev, params[key + "_om"])
+  o = tf.nn.bias_add(o, params[key + "_o"])
+  i = tf.sigmoid(i, name="i_gate")
+  f = tf.sigmoid(f, name="f_gate")
+  o = tf.sigmoid(o, name="o_gate")
+  c = f * cprev + i * tf.tanh(c)
+  m = o * c
+  return m, c

+ 89 - 0
parameters.py

@@ -0,0 +1,89 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Author: aneelakantan (Arvind Neelakantan)
+"""
+
+import numpy as np
+import tensorflow as tf
+
+
+class Parameters:
+
+  def __init__(self, u):
+    self.utility = u
+    self.init_seed_counter = 0
+    self.word_init = {}
+
+  def parameters(self, utility):
+    params = {}
+    inits = []
+    embedding_dims = self.utility.FLAGS.embedding_dims
+    params["unit"] = tf.Variable(
+        self.RandomUniformInit([len(utility.operations_set), embedding_dims]))
+    params["word"] = tf.Variable(
+        self.RandomUniformInit([utility.FLAGS.vocab_size, embedding_dims]))
+    params["word_match_feature_column_name"] = tf.Variable(
+        self.RandomUniformInit([1]))
+    params["controller"] = tf.Variable(
+        self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
+    params["column_controller"] = tf.Variable(
+        self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
+    params["column_controller_prev"] = tf.Variable(
+        self.RandomUniformInit([embedding_dims, embedding_dims]))
+    params["controller_prev"] = tf.Variable(
+        self.RandomUniformInit([embedding_dims, embedding_dims]))
+    global_step = tf.Variable(1, name="global_step")
+    #weights of question and history RNN (or LSTM)
+    key_list = ["question_lstm"]
+    for key in key_list:
+      # Weights going from inputs to nodes.
+      for wgts in ["ix", "fx", "cx", "ox"]:
+        params[key + "_" + wgts] = tf.Variable(
+            self.RandomUniformInit([embedding_dims, embedding_dims]))
+      # Weights going from nodes to nodes.
+      for wgts in ["im", "fm", "cm", "om"]:
+        params[key + "_" + wgts] = tf.Variable(
+            self.RandomUniformInit([embedding_dims, embedding_dims]))
+      #Biases for the gates and cell
+      for bias in ["i", "f", "c", "o"]:
+        if (bias == "f"):
+          print "forget gate bias"
+          params[key + "_" + bias] = tf.Variable(
+              tf.random_uniform([embedding_dims], 1.0, 1.1, self.utility.
+                                tf_data_type[self.utility.FLAGS.data_type]))
+        else:
+          params[key + "_" + bias] = tf.Variable(
+              self.RandomUniformInit([embedding_dims]))
+    params["history_recurrent"] = tf.Variable(
+        self.RandomUniformInit([3 * embedding_dims, embedding_dims]))
+    params["history_recurrent_bias"] = tf.Variable(
+        self.RandomUniformInit([1, embedding_dims]))
+    params["break_conditional"] = tf.Variable(
+        self.RandomUniformInit([2 * embedding_dims, embedding_dims]))
+    init = tf.initialize_all_variables()
+    return params, global_step, init
+
+  def RandomUniformInit(self, shape):
+    """Returns a RandomUniform Tensor between -param_init and param_init."""
+    param_seed = self.utility.FLAGS.param_seed
+    self.init_seed_counter += 1
+    return tf.random_uniform(
+        shape, -1.0 *
+        (np.float32(self.utility.FLAGS.param_init)
+        ).astype(self.utility.np_data_type[self.utility.FLAGS.data_type]),
+        (np.float32(self.utility.FLAGS.param_init)
+        ).astype(self.utility.np_data_type[self.utility.FLAGS.data_type]),
+        self.utility.tf_data_type[self.utility.FLAGS.data_type],
+        param_seed + self.init_seed_counter)

+ 530 - 0
wiki_data.py

@@ -0,0 +1,530 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loads the WikiQuestions dataset.
+
+An example consists of a question and a table. Additionally, we store the
+processed columns, which hold the entries after number, date and other
+preprocessing as done in the baseline.
+Columns, column names and processed columns are split into word and number
+columns.
+The lookup answer (or matrix) is also split into a number and a word lookup
+matrix.
+Author: aneelakantan (Arvind Neelakantan)
+"""
+import math
+import os
+import re
+import numpy as np
+import unicodedata as ud
+import tensorflow as tf
+
+bad_number = -200000.0  #number that is added to a corrupted table entry in a number column
+
+def is_nan_or_inf(number):
+  return math.isnan(number) or math.isinf(number)
+
+def strip_accents(s):
+  u = unicode(s, "utf-8")
+  u_new = ''.join(c for c in ud.normalize('NFKD', u) if ud.category(c) != 'Mn')
+  return u_new.encode("utf-8")
+
+
+def correct_unicode(string):
+  string = strip_accents(string)
+  string = re.sub("\xc2\xa0", " ", string).strip()
+  string = re.sub("\xe2\x80\x93", "-", string).strip()
+  #string = re.sub(ur'[\u0300-\u036F]', "", string)
+  string = re.sub("‚", ",", string)
+  string = re.sub("…", "...", string)
+  #string = re.sub("[·・]", ".", string)
+  string = re.sub("ˆ", "^", string)
+  string = re.sub("˜", "~", string)
+  string = re.sub("‹", "<", string)
+  string = re.sub("›", ">", string)
+  #string = re.sub("[‘’´`]", "'", string)
+  #string = re.sub("[“”«»]", "\"", string)
+  #string = re.sub("[•†‡]", "", string)
+  #string = re.sub("[‐‑–—]", "-", string)
+  string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
+  string = re.sub("\\s+", " ", string).strip()
+  return string
+
+
+def simple_normalize(string):
+  string = correct_unicode(string)
+  # Citations
+  string = re.sub("\[(nb ?)?\d+\]", "", string)
+  string = re.sub("\*+$", "", string)
+  # Year in parenthesis
+  string = re.sub("\(\d* ?-? ?\d*\)", "", string)
+  string = re.sub("^\"(.*)\"$", "", string)
+  return string
+
+
+def full_normalize(string):
+  #print "an: ", string
+  string = simple_normalize(string)
+  # Remove trailing info in brackets
+  string = re.sub("\[[^\]]*\]", "", string)
+  # Remove most unicode characters in other languages
+  string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
+  # Remove trailing info in parenthesis
+  string = re.sub("\([^)]*\)$", "", string.strip())
+  string = final_normalize(string)
+  # Get rid of question marks
+  string = re.sub("\?", "", string).strip()
+  # Get rid of trailing colons (usually occur in column titles)
+  string = re.sub("\:$", " ", string).strip()
+  # Get rid of slashes
+  string = re.sub(r"/", " ", string).strip()
+  string = re.sub(r"\\", " ", string).strip()
+  # Replace colon, slash, and dash with space
+  # Note: need better replacement for this when parsing time
+  string = re.sub(r"\:", " ", string).strip()
+  string = re.sub("/", " ", string).strip()
+  string = re.sub("-", " ", string).strip()
+  # Convert empty strings to UNK
+  # Important to do this last or near last
+  if not string:
+    string = "UNK"
+  return string
+
+def final_normalize(string):
+  # Remove leading and trailing whitespace
+  string = re.sub("\\s+", " ", string).strip()
+  # Convert entirely to lowercase
+  string = string.lower()
+  # Get rid of strangely escaped newline characters
+  string = re.sub("\\\\n", " ", string).strip()
+  # Get rid of quotation marks
+  string = re.sub(r"\"", "", string).strip()
+  string = re.sub(r"\'", "", string).strip()
+  string = re.sub(r"`", "", string).strip()
+  # Get rid of *
+  string = re.sub("\*", "", string).strip()
+  return string
+
+def is_number(x):
+  try:
+    f = float(x)
+    return not is_nan_or_inf(f)
+  except ValueError:
+    return False
+  except TypeError:
+    return False
+
+
+class WikiExample(object):
+
+  def __init__(self, id, question, answer, table_key):
+    self.question_id = id
+    self.question = question
+    self.answer = answer
+    self.table_key = table_key
+    self.lookup_matrix = []
+    self.is_bad_example = False
+    self.is_word_lookup = False
+    self.is_ambiguous_word_lookup = False
+    self.is_number_lookup = False
+    self.is_number_calc = False
+    self.is_unknown_answer = False
+
+
+class TableInfo(object):
+
+  def __init__(self, word_columns, word_column_names, word_column_indices,
+               number_columns, number_column_names, number_column_indices,
+               processed_word_columns, processed_number_columns, orig_columns):
+    self.word_columns = word_columns
+    self.word_column_names = word_column_names
+    self.word_column_indices = word_column_indices
+    self.number_columns = number_columns
+    self.number_column_names = number_column_names
+    self.number_column_indices = number_column_indices
+    self.processed_word_columns = processed_word_columns
+    self.processed_number_columns = processed_number_columns
+    self.orig_columns = orig_columns
+
+
+class WikiQuestionLoader(object):
+
+  def __init__(self, data_name, root_folder):
+    self.root_folder = root_folder
+    self.data_folder = os.path.join(self.root_folder, "data")
+    self.examples = []
+    self.data_name = data_name
+
+  def num_questions(self):
+    return len(self.examples)
+
+  def load_qa(self):
+    data_source = os.path.join(self.data_folder, self.data_name)
+    f = tf.gfile.GFile(data_source, "r")
+    id_regex = re.compile("\(id ([^\)]*)\)")
+    for line in f:
+      id_match = id_regex.search(line)
+      id = id_match.group(1)
+      self.examples.append(id)
+
+  def load(self):
+    self.load_qa()
+
+
+def is_date(word):
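+  """True for annotator-normalized date strings: 10 characters of the form
+  YYYY-MM-DD where unknown fields may use "X"/"x" (e.g. "2014-xx-xx")."""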
+  if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
+    return False
+  if (len(word) != 10):
+    return False
+  if (word[4] != "-"):
+    return False
+  if (word[7] != "-"):
+    return False
+  for i in range(len(word)):
+    if (not (word[i] == "X" or word[i] == "x" or word[i] == "-" or re.search(
+        "[0-9]", word[i]))):
+      return False
+  return True
+
+
+class WikiQuestionGenerator(object):
+
+  def __init__(self, train_name, dev_name, test_name, root_folder):
+    self.train_name = train_name
+    self.dev_name = dev_name
+    self.test_name = test_name
+    self.train_loader = WikiQuestionLoader(train_name, root_folder)
+    self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
+    self.test_loader = WikiQuestionLoader(test_name, root_folder)
+    self.bad_examples = 0
+    self.root_folder = root_folder   
+    self.data_folder = os.path.join(self.root_folder, "annotated/data")
+    self.annotated_examples = {}
+    self.annotated_tables = {}
+    self.annotated_word_reject = {}
+    self.annotated_word_reject["-lrb-"] = 1
+    self.annotated_word_reject["-rrb-"] = 1
+    self.annotated_word_reject["UNK"] = 1
+
+  def is_money(self, word):
+    if (not (bool(re.search("[a-z0-9]", word, re.IGNORECASE)))):
+      return False
+    for i in range(len(word)):
+      if (not (word[i] == "E" or word[i] == "." or re.search("[0-9]",
+                                                             word[i]))):
+        return False
+    return True
+
+  def remove_consecutive(self, ner_tags, ner_values):
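+    # For a run of consecutive tokens that share the same NUMBER/MONEY/PERCENT/
+    # DATE value, blank out the value on all but the last token so the
+    # normalized number is not duplicated in the question.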
+    for i in range(len(ner_tags)):
+      if ((ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
+           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE") and
+          i + 1 < len(ner_tags) and ner_tags[i] == ner_tags[i + 1] and
+          ner_values[i] == ner_values[i + 1] and ner_values[i] != ""):
+        word = ner_values[i]
+        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
+            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
+                "€", "")
+        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
+            self.is_money(word))):
+          ner_values[i] = "A"
+        else:
+          ner_values[i] = ","
+    return ner_tags, ner_values
+
+  def pre_process_sentence(self, tokens, ner_tags, ner_values):
+    sentence = []
+    tokens = tokens.split("|")
+    ner_tags = ner_tags.split("|")
+    ner_values = ner_values.split("|")
+    ner_tags, ner_values = self.remove_consecutive(ner_tags, ner_values)
+    #print "old: ", tokens
+    for i in range(len(tokens)):
+      word = tokens[i]
+      if (ner_values[i] != "" and
+          (ner_tags[i] == "NUMBER" or ner_tags[i] == "MONEY" or
+           ner_tags[i] == "PERCENT" or ner_tags[i] == "DATE")):
+        word = ner_values[i]
+        word = word.replace(">", "").replace("<", "").replace("=", "").replace(
+            "%", "").replace("~", "").replace("$", "").replace("£", "").replace(
+                "€", "")
+        if (re.search("[A-Z]", word) and not (is_date(word)) and not (
+            self.is_money(word))):
+          word = tokens[i]
+        if (is_number(ner_values[i])):
+          word = float(ner_values[i])
+        elif (is_number(word)):
+          word = float(word)
+        if (tokens[i] == "score"):
+          word = "score"
+      if (is_number(word)):
+        word = float(word)
+      if (not (self.annotated_word_reject.has_key(word))):
+        if (is_number(word) or is_date(word) or self.is_money(word)):
+          sentence.append(word)
+        else:
+          word = full_normalize(word)
+          if (not (self.annotated_word_reject.has_key(word)) and
+              bool(re.search("[a-z0-9]", word, re.IGNORECASE))):
+            sentence.append(word.replace(",", ""))
+    if (len(sentence) == 0):
+      sentence.append("UNK")
+    return sentence
+
+  def load_annotated_data(self, in_file):
+    self.annotated_examples = {}
+    self.annotated_tables = {}
+    f = tf.gfile.GFile(in_file, "r")
+    counter = 0
+    for line in f:
+      if (counter > 0):
+        line = line.strip()
+        (question_id, utterance, context, target_value, tokens, lemma_tokens,
+         pos_tags, ner_tags, ner_values, target_canon) = line.split("\t")
+        question = self.pre_process_sentence(tokens, ner_tags, ner_values)
+        target_canon = target_canon.split("|")
+        self.annotated_examples[question_id] = WikiExample(
+            question_id, question, target_canon, context)
+        self.annotated_tables[context] = []
+      counter += 1
+    print "Annotated examples loaded ", len(self.annotated_examples)
+    f.close()
+
+  def is_number_column(self, a):
+    for w in a:
+      if (len(w) != 1):
+        return False
+      if (not (is_number(w[0]))):
+        return False
+    return True
+
+  def convert_table(self, table):
+    answer = []
+    for i in range(len(table)):
+      temp = []
+      for j in range(len(table[i])):
+        temp.append(" ".join([str(w) for w in table[i][j]]))
+      answer.append(temp)
+    return answer
+
+  def load_annotated_tables(self):
+    for table in self.annotated_tables.keys():
+      annotated_table = table.replace("csv", "annotated")
+      orig_columns = []
+      processed_columns = []
+      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
+      counter = 0
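+      # First pass over the annotated table: only the row and column indices of
+      # the last line are kept, which gives the table size used to pre-allocate
+      # the column arrays below.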
+      for line in f:
+        if (counter > 0):
+          line = line.strip()
+          line = line + "\t" * (13 - len(line.split("\t")))
+          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
+           ner_values, number, date, num2, read_list) = line.split("\t")
+        counter += 1
+      f.close()
+      max_row = int(row)
+      max_col = int(col)
+      for i in range(max_col + 1):
+        orig_columns.append([])
+        processed_columns.append([])
+        for j in range(max_row + 1):
+          orig_columns[i].append(bad_number)
+          processed_columns[i].append(bad_number)
+      #print orig_columns
+      f = tf.gfile.GFile(os.path.join(self.root_folder, annotated_table), "r")
+      counter = 0
+      column_names = []
+      for line in f:
+        if (counter > 0):
+          line = line.strip()
+          line = line + "\t" * (13 - len(line.split("\t")))
+          (row, col, read_id, content, tokens, lemma_tokens, pos_tags, ner_tags,
+           ner_values, number, date, num2, read_list) = line.split("\t")
+          entry = self.pre_process_sentence(tokens, ner_tags, ner_values)
+          if (row == "-1"):
+            column_names.append(entry)
+          else:
+            orig_columns[int(col)][int(row)] = entry
+            if (len(entry) == 1 and is_number(entry[0])):
+              processed_columns[int(col)][int(row)] = float(entry[0])
+            else:
+              for single_entry in entry:
+                if (is_number(single_entry)):
+                  processed_columns[int(col)][int(row)] = float(single_entry)
+                  break
+              nt = ner_tags.split("|")
+              nv = ner_values.split("|")
+              for i_entry in range(len(tokens.split("|"))):
+                if (nt[i_entry] == "DATE" and
+                    is_number(nv[i_entry].replace("-", "").replace("X", ""))):
+                  processed_columns[int(col)][int(row)] = float(nv[
+                      i_entry].replace("-", "").replace("X", ""))
+                  #processed_columns[int(col)][int(row)] =  float(nv[i_entry])
+            if (len(entry) == 1 and (is_number(entry[0]) or is_date(entry[0]) or
+                                     self.is_money(entry[0]))):
+              if (len(entry) == 1 and not (is_number(entry[0])) and
+                  is_date(entry[0])):
+                entry[0] = entry[0].replace("X", "x")
+        counter += 1
+      word_columns = []
+      processed_word_columns = []
+      word_column_names = []
+      word_column_indices = []
+      number_columns = []
+      processed_number_columns = []
+      number_column_names = []
+      number_column_indices = []
+      for i in range(max_col + 1):
+        if (self.is_number_column(orig_columns[i])):
+          number_column_indices.append(i)
+          number_column_names.append(column_names[i])
+          temp = []
+          for w in orig_columns[i]:
+            if (is_number(w[0])):
+              temp.append(w[0])
+          number_columns.append(temp)
+          processed_number_columns.append(processed_columns[i])
+        else:
+          word_column_indices.append(i)
+          word_column_names.append(column_names[i])
+          word_columns.append(orig_columns[i])
+          processed_word_columns.append(processed_columns[i])
+      table_info = TableInfo(
+          word_columns, word_column_names, word_column_indices, number_columns,
+          number_column_names, number_column_indices, processed_word_columns,
+          processed_number_columns, orig_columns)
+      self.annotated_tables[table] = table_info
+      f.close()
+
+  def answer_classification(self):
+    lookup_questions = 0
+    number_lookup_questions = 0
+    word_lookup_questions = 0
+    ambiguous_lookup_questions = 0
+    number_questions = 0
+    bad_questions = 0
+    ice_bad_questions = 0
+    tot = 0
+    got = 0
+    ice = {}
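+    # "ice" maps a question id to its rows in arvind-with-norms-2.tsv; each row
+    # records one answer string and the table cells ("row,col" pairs) it
+    # matches.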
+    with tf.gfile.GFile(
+        self.root_folder + "/arvind-with-norms-2.tsv", mode="r") as f:
+      lines = f.readlines()
+      for line in lines:
+        line = line.strip()
+        if (not (self.annotated_examples.has_key(line.split("\t")[0]))):
+          continue
+        if (len(line.split("\t")) == 4):
+          line = line + "\t" * (5 - len(line.split("\t")))
+          if (not (is_number(line.split("\t")[2]))):
+            ice_bad_questions += 1
+        (example_id, ans_index, ans_raw, process_answer,
+         matched_cells) = line.split("\t")
+        if (ice.has_key(example_id)):
+          ice[example_id].append(line.split("\t"))
+        else:
+          ice[example_id] = [line.split("\t")]
+    for q_id in self.annotated_examples.keys():
+      tot += 1
+      example = self.annotated_examples[q_id]
+      table_info = self.annotated_tables[example.table_key]
+      # Figure out if the answer is numerical or lookup
+      n_cols = len(table_info.orig_columns)
+      n_rows = len(table_info.orig_columns[0])
+      example.lookup_matrix = np.zeros((n_rows, n_cols))
+      exact_matches = {}
+      for (example_id, ans_index, ans_raw, process_answer,
+           matched_cells) in ice[q_id]:
+        for match_cell in matched_cells.split("|"):
+          if (len(match_cell.split(",")) == 2):
+            (row, col) = match_cell.split(",")
+            row = int(row)
+            col = int(col)
+            if (row >= 0):
+              exact_matches[ans_index] = 1
+      answer_is_in_table = len(exact_matches) == len(example.answer)
+      if (answer_is_in_table):
+        for (example_id, ans_index, ans_raw, process_answer,
+             matched_cells) in ice[q_id]:
+          for match_cell in matched_cells.split("|"):
+            if (len(match_cell.split(",")) == 2):
+              (row, col) = match_cell.split(",")
+              row = int(row)
+              col = int(col)
+              example.lookup_matrix[row, col] = float(ans_index) + 1.0
+      example.lookup_number_answer = 0.0
+      if (answer_is_in_table):
+        lookup_questions += 1
+        if len(example.answer) == 1 and is_number(example.answer[0]):
+          example.number_answer = float(example.answer[0])
+          number_lookup_questions += 1
+          example.is_number_lookup = True
+        else:
+          #print "word lookup"
+          example.calc_answer = example.number_answer = 0.0
+          word_lookup_questions += 1
+          example.is_word_lookup = True
+      else:
+        if (len(example.answer) == 1 and is_number(example.answer[0])):
+          example.number_answer = example.answer[0]
+          example.is_number_calc = True
+        else:
+          bad_questions += 1
+          example.is_bad_example = True
+          example.is_unknown_answer = True
+      example.is_lookup = example.is_word_lookup or example.is_number_lookup
+      if not example.is_word_lookup and not example.is_bad_example:
+        number_questions += 1
+        example.calc_answer = example.answer[0]
+        example.lookup_number_answer = example.calc_answer
+      # Split up the lookup matrix into word part and number part
+      number_column_indices = table_info.number_column_indices
+      word_column_indices = table_info.word_column_indices
+      example.word_columns = table_info.word_columns
+      example.number_columns = table_info.number_columns
+      example.word_column_names = table_info.word_column_names
+      example.processed_number_columns = table_info.processed_number_columns
+      example.processed_word_columns = table_info.processed_word_columns
+      example.number_column_names = table_info.number_column_names
+      example.number_lookup_matrix = example.lookup_matrix[:,
+                                                           number_column_indices]
+      example.word_lookup_matrix = example.lookup_matrix[:, word_column_indices]
+
+  def load(self):
+    train_data = []
+    dev_data = []
+    test_data = []
+    self.load_annotated_data(
+        os.path.join(self.data_folder, "training.annotated"))
+    self.load_annotated_tables()
+    self.answer_classification()
+    self.train_loader.load()
+    self.dev_loader.load()
+    for i in range(self.train_loader.num_questions()):
+      example = self.train_loader.examples[i]
+      example = self.annotated_examples[example]
+      train_data.append(example)
+    for i in range(self.dev_loader.num_questions()):
+      example = self.dev_loader.examples[i]
+      dev_data.append(self.annotated_examples[example])
+
+    self.load_annotated_data(
+        os.path.join(self.data_folder, "pristine-unseen-tables.annotated"))
+    self.load_annotated_tables()
+    self.answer_classification()
+    self.test_loader.load()
+    for i in range(self.test_loader.num_questions()):
+      example = self.test_loader.examples[i]
+      test_data.append(self.annotated_examples[example])
+    return train_data, dev_data, test_data