Преглед изворни кода

Raises AssertionError on Incomplete Vocabulary

fixes issue #621
added a new function CheckVocab, to check for presence of a word in vocabulary
Pranay Mathur пре 8 година
родитељ
комит
27a178a955
2 измењених фајлова са 10 додато и 4 уклоњено
  1. 6 0
      textsum/data.py
  2. 4 4
      textsum/seq2seq_attention.py

+ 6 - 0
textsum/data.py

@@ -56,6 +56,12 @@ class Vocab(object):
         if self._count > max_size:
           raise ValueError('Too many words: >%d.' % max_size)
 
+  def CheckVocab(self, word):
+    if word not in self._word_to_id:
+      return None
+    return self._word_to_id[word]
+
+  
   def WordToId(self, word):
     if word not in self._word_to_id:
       return self._word_to_id[UNKNOWN_TOKEN]

+ 4 - 4
textsum/seq2seq_attention.py

@@ -160,10 +160,10 @@ def _Eval(model, data_batcher, vocab=None):
 def main(unused_argv):
   vocab = data.Vocab(FLAGS.vocab_path, 1000000)
   # Check for presence of required special tokens.
-  assert vocab.WordToId(data.PAD_TOKEN) > 0
-  assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
-  assert vocab.WordToId(data.SENTENCE_START) > 0
-  assert vocab.WordToId(data.SENTENCE_END) > 0
+  assert vocab.CheckVocab(data.PAD_TOKEN) > 0
+  assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
+  assert vocab.CheckVocab(data.SENTENCE_START) > 0
+  assert vocab.CheckVocab(data.SENTENCE_END) > 0
 
   batch_size = 4
   if FLAGS.mode == 'decode':