|
@@ -160,10 +160,10 @@ def _Eval(model, data_batcher, vocab=None):
|
|
|
def main(unused_argv):
|
|
|
vocab = data.Vocab(FLAGS.vocab_path, 1000000)
|
|
|
# Check for presence of required special tokens.
|
|
|
- assert vocab.WordToId(data.PAD_TOKEN) > 0
|
|
|
- assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
|
|
|
- assert vocab.WordToId(data.SENTENCE_START) > 0
|
|
|
- assert vocab.WordToId(data.SENTENCE_END) > 0
|
|
|
+ assert vocab.CheckVocab(data.PAD_TOKEN) > 0
|
|
|
+ assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
|
|
|
+ assert vocab.CheckVocab(data.SENTENCE_START) > 0
|
|
|
+ assert vocab.CheckVocab(data.SENTENCE_END) > 0
|
|
|
|
|
|
batch_size = 4
|
|
|
if FLAGS.mode == 'decode':
|