@@ -1,19 +1,3 @@
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
-
 """Neural GPU for Learning Algorithms."""

 import math
@@ -22,16 +6,15 @@ import random
 import sys
 import time

-import google3
-
 import matplotlib.animation as anim
 import matplotlib.pyplot as plt
 import numpy as np
 import tensorflow as tf

-from google3.third_party.tensorflow.python.platform import gfile
-import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
-import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu
+from tensorflow.python.platform import gfile
+
+import data_utils as data
+import neural_gpu

 tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
 tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
@@ -39,7 +22,7 @@ tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
 tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
 tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
 tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
-tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.")
+tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
 tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
 tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
 tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
@@ -63,6 +46,7 @@ tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
 tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")

 FLAGS = tf.app.flags.FLAGS
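+# How many lengths beyond the training maximum to evaluate on (see train).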
+EXTRA_EVAL = 12


 def initialize(sess):
@@ -83,7 +67,7 @@ def initialize(sess):
   min_length = 3
   max_length = min(FLAGS.max_length, data.bins[-1])
   assert max_length + 1 > min_length
-  while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
+  while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
     data.bins = data.bins[:-1]
   assert data.bins[0] > FLAGS.rx_step
   nclass = min(FLAGS.niclass, FLAGS.noclass)
@@ -92,7 +76,7 @@ def initialize(sess):
   # Initialize data for each task.
   tasks = FLAGS.task.split("-")
   for t in tasks:
-    for l in xrange(max_length + 11):
+    for l in xrange(max_length + EXTRA_EVAL - 1):
       data.init_data(t, l, data_size, nclass)
     data.init_data(t, data.bins[-2], data_size, nclass)
     data.init_data(t, data.bins[-1], data_size, nclass)
@@ -101,14 +85,14 @@ def initialize(sess):

   # Print out parameters.
   curriculum = 0.12
-  fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
-         % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
-            FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
-  fin = "data %d %s" % (FLAGS.train_data_size, fin)
-  tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
-         (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
-          curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
-  data.print_out(tag)
+  msg1 = ("layers %d kw %d h %d kh %d relax %d batch %d noise %.2f task %s"
+          % (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
+             FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
+  msg2 = "data %d %s" % (FLAGS.train_data_size, msg1)
+  msg3 = ("cut %.2f pull %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
+          (FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
+           curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, msg2))
+  data.print_out(msg3)

   # Create checkpoint directory if it does not exist.
   checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
@@ -120,7 +104,7 @@ def initialize(sess):
   # Create model and initialize it.
   tf.get_variable_scope().set_initializer(
       tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
-  model = ngpu.NeuralGPU(
+  model = neural_gpu.NeuralGPU(
       FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
       FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
       FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
@@ -145,131 +129,148 @@ def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
   """Test model on test data of length l using the given session."""
   inpt, target = data.get_batch(l, batch_size, False, task, offset)
   _, res, _, steps = model.step(sess, inpt, target, False)
-  errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
-  seq = float(seq) / batch_size
+  errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
+  seq_err = float(seq_err) / batch_size
   if total > 0:
     errors = float(errors) / total
   if print_out:
     data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
-                   % (task, l, 100*errors, 100*seq))
-  return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])
+                   % (task, l, 100*errors, 100*seq_err))
+  return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])


 def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
   """Run multiple tests at lower batch size to save memory."""
-  errors = 0.0
-  seq = 0.0
+  errors, seq_err = 0.0, 0.0
   to_print = nprint
   low_batch = FLAGS.low_batch_size
   low_batch = min(low_batch, batch_size)
   for mstep in xrange(batch_size / low_batch):
     cur_offset = None if offset is None else offset + mstep * low_batch
-    err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
-                             cur_offset)
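+    # Run one low-batch test and add its errors to the totals.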
+    err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
+                                 False, cur_offset)
     to_print = max(0, to_print - low_batch)
     errors += err
-    seq += sq
+    seq_err += sq_err
     if FLAGS.mode > 0:
       cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
-      cur_seq = float(low_batch * seq) / ((mstep+1) * low_batch)
+      cur_seq_err = float(low_batch * seq_err) / ((mstep+1) * low_batch)
       data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
-                     % (task, 100*cur_errors, 100*cur_seq))
+                     % (task, 100*cur_errors, 100*cur_seq_err))
   errors = float(low_batch) * float(errors) / batch_size
-  seq = float(low_batch) * float(seq) / batch_size
+  seq_err = float(low_batch) * float(seq_err) / batch_size
   data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
-                 % (task, l, 100*errors, 100*seq))
-  return errors, seq
+                 % (task, l, 100*errors, 100*seq_err))
+  return errors, seq_err


 def train():
-  """Main training function."""
+  """Train the model."""
   batch_size = FLAGS.batch_size
   tasks = FLAGS.task.split("-")
   with tf.Session() as sess:
     model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
     max_cur_length = min(min_length + 3, max_length)
     prev_acc_perp = [1000000 for _ in xrange(3)]
-    prev_sq = 1.0
+    prev_seq_err = 1.0

+    # Main training loop.
     while True:
       global_step, pull, max_cur_length, learning_rate = sess.run(
           [model.global_step, model.pull, model.cur_length, model.lr])
-      ep = global_step / FLAGS.steps_per_checkpoint
-      acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
+      acc_loss, acc_total, acc_errors, acc_seq_err = 0.0, 0, 0, 0
       acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
       for _ in xrange(FLAGS.steps_per_checkpoint):
         global_step += 1
         task = random.choice(tasks)
-        l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
-        l = l1
-        if np.random.randint(10) > 3:  # Prefer longer stuff 60% of time.
-          l = np.random.randint(max_cur_length - min_length+1) + min_length
+
+        # Select the length for curriculum learning.
+        l = np.random.randint(max_cur_length - min_length + 1) + min_length
+        # Prefer longer stuff 60% of time.
+        if np.random.randint(100) < 60:
+          l1 = np.random.randint(max_cur_length - min_length+1) + min_length
           l = max(l, l1)
-        if np.random.randint(4) < 1:  # Mixed learning: once in a while big.
-          l = np.random.randint(max_length - min_length + 1) + min_length
+        # Mixed curriculum learning: in 25% of cases go to any larger length.
+        if np.random.randint(100) < 25:
+          l1 = np.random.randint(max_length - min_length + 1) + min_length
           l = max(l, l1)
+
+        # Run a step and time it.
         start_time = time.time()
         inp, target = data.get_batch(l, batch_size, True, task)
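+        # Gradient noise decays with global_step and the last sequence error.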
-        stepp = math.pow(global_step, -0.55)
-        noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
+        noise_param = math.sqrt(math.pow(global_step, -0.55) *
+                                (20 * prev_seq_err)) * FLAGS.grad_noise_scale
         loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
         step_time += time.time() - start_time
         acc_grad_norm += float(gnorm)
+
+        # Accumulate statistics only if we did not exceed curriculum length.
         if l < max_cur_length + 1:
           step_count += 1
           acc_loss += loss
-          errors, total, seq = data.accuracy(inp, res, target,
-                                             batch_size, 0)
+          errors, total, seq_err = data.accuracy(inp, res, target,
+                                                 batch_size, 0)
           acc_total += total
           acc_errors += errors
-          acc_seq += seq
+          acc_seq_err += seq_err
+
+      # Normalize and print out accumulated statistics.
       acc_loss /= step_count
       step_time /= FLAGS.steps_per_checkpoint
-      acc_seq = float(acc_seq) / (step_count * batch_size)
-      prev_sq = acc_seq
+      acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
+      prev_seq_err = acc_seq_err
      acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
-      msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
-      msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
-      msg = ("%s %s gn %.8f"
-             % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
-      data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
-                     (msg, max_cur_length, data.safe_exp(acc_loss),
-                      100*acc_errors, 100*acc_seq))
-      if curriculum > acc_seq:
-        prev_acc_perp.append(1000000)
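+      # Compose the status line: step, timings, lr, pull, grad-norm.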
+      msg1 = "step %d step-time %.2f" % (global_step, step_time)
+      msg2 = "lr %.8f pull %.3f" % (learning_rate, pull)
+      msg3 = ("%s %s grad-norm %.8f"
+              % (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
+      data.print_out("%s len %d ppx %.8f errors %.2f sequence-errors %.2f" %
+                     (msg3, max_cur_length, data.safe_exp(acc_loss),
+                      100*acc_errors, 100*acc_seq_err))
+
+      # If errors are below the curriculum threshold, move curriculum forward.
+      if curriculum > acc_seq_err:
+        # Increase current length (until the next with training data).
         do_incr = True
         while do_incr and max_cur_length < max_length:
           sess.run(model.cur_length_incr_op)
           for t in tasks:
             if data.train_set[t]: do_incr = False
+        # Forget last perplexities if we're not yet at the end.
+        if max_cur_length < max_length:
+          prev_acc_perp.append(1000000)
+        # Either increase pull or, if it's large, average parameters.
         if pull < 1:
           sess.run(model.pull_incr_op)
         else:
           data.print_out(" Averaging parameters.")
           sess.run([model.avg_op, model.lr_decay_op])
-      else:
-        acc_perp = data.safe_exp(acc_loss)
-        if acc_perp > max(prev_acc_perp[-3:]):
-          sess.run(model.lr_decay_op)
-        prev_acc_perp.append(acc_perp)
+
+      # Lower learning rate if we're worse than the last 3 checkpoints.
+      acc_perp = data.safe_exp(acc_loss)
+      if acc_perp > max(prev_acc_perp[-3:]):
+        sess.run(model.lr_decay_op)
+      prev_acc_perp.append(acc_perp)
+
+      # Save checkpoint.
      checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
      model.saver.save(sess, checkpoint_path,
                       global_step=model.global_step)
+
      # Run evaluation.
-      should_exit = True
      bound = data.bins[-1] + 1
      for t in tasks:
        l = min_length
-        while l < max_length + 12 and l < bound:
-          _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+        while l < max_length + EXTRA_EVAL and l < bound:
+          _, seq_err, _ = single_test(l, model, sess, t,
+                                      FLAGS.nprint, batch_size)
          l += 1
          while l < bound + 1 and not data.test_set[t][l]:
            l += 1
-        if sq < 0.5:
-          _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-                             batch_size * 4)
-        if sq > 0.001: should_exit = False
-      if should_exit:
+        if seq_err < 0.5:  # Run larger test if we're good enough.
+          _, seq_err = multi_test(data.forward_max, model, sess, t,
+                                  FLAGS.nprint, batch_size * 4)
+      if seq_err < 0.01:  # Super-large test on 1-task large-forward models.
        if data.forward_max > 4000 and len(tasks) == 1:
          multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                     batch_size * 16, 0)
@@ -277,14 +278,17 @@ def train():

 def animate(l, test_data, anim_size):
   """Create animation for the given data (hacky matplotlib use)."""
-  xf = 12
-  fps = 2
+  xf = 12  # Extra frames to slow down at start and end.
+  fps = 2  # Frames per step.
+
+  # Make the figure.
   fig = plt.figure(figsize=(16, 9), facecolor="white")
   ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
   ax.set_xticks([i * 24-0.5 for i in xrange(4)])
   ax.set_xticklabels([])
   ax.set_yticks([i - 0.5 for i in xrange(l+1)])
   ax.grid(which="major", axis="both", linestyle="-", color="black")
+  # We need text fields.
   text_fields = []
   text_size = 24*32/l
   for y in xrange(l):
@@ -296,11 +300,8 @@ def animate(l, test_data, anim_size):
                   vmax=1.0, cmap="gray", aspect="auto", origin="upper",
                   interpolation="none", animated=True)
   im.set_zorder(1)
-  def to_symbol(i):
-    if i == 0: return ""
-    if i == 11: return "+"
-    if i == 12: return "*"
-    return str(i-1)
+
+  # Main animation step.
   def animation_update(frame_no, test_data, xf, im, text_fields):
     """Update an animation frame."""
     steps, inpt, out_raw = test_data
@@ -319,15 +320,17 @@ def animate(l, test_data, anim_size):
         if index - 2*xf < length:
           t.set_text("")
         else:
-          t.set_text(to_symbol(out[i]))
+          t.set_text(data.to_symbol(out[i]))
     else:
       for i, t in enumerate(text_fields):
-        t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
+        t.set_text(data.to_symbol(inpt[i][batch]) if index < xf else "")
       if index < xf:
         im.set_array(np.zeros_like(steps[0][0]))
       else:
         im.set_array(steps[0][batch])
     return im,
+
+  # Create the animation and save to mp4.
   animation = anim.FuncAnimation(
       fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps,
       interval=500/fps, fargs=(test_data, xf, im, text_fields))
@@ -343,8 +346,8 @@ def evaluate():
     bound = data.bins[-1] + 1
     for t in tasks:
       l = min_length
-      while l < max_length + 12 and l < bound:
-        _, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+      while l < max_length + EXTRA_EVAL and l < bound:
+        _, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
        l += 1
        while l < bound + 1 and not data.test_set[t][l]:
          l += 1
@@ -353,9 +356,9 @@ def evaluate():
         _, _, test_data = single_test(l, model, sess, t, 0, anim_size)
         animate(l, test_data, anim_size)
       # More tests.
-      _, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-                         batch_size * 4)
-      if sq < 0.01:  # More tests.
+      _, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
+                              batch_size * 4)
+      if seq_err < 0.01:  # Super-test if we're very good and in large-test mode.
         if data.forward_max > 4000 and len(tasks) == 1:
           multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
                      batch_size * 64, 0)
@@ -365,16 +368,18 @@ def interactive():
   """Interactively probe an existing model."""
   with tf.Session() as sess:
     model, _, _, _, _ = initialize(sess)
+    sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
     sys.stdout.write("> ")
     sys.stdout.flush()
     inpt = sys.stdin.readline()
     while inpt:
-      ids = [int(c) for c in inpt.strip()]
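+      # Parse the space-separated input symbols into ids.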
+      ids = [data.to_id(s) for s in inpt.strip().split()]
       inpt, target = data.get_batch(len(ids), 1, False, "",
                                     preset=(ids, [0 for _ in ids]))
       _, res, _, _ = model.step(sess, inpt, target, False)
       res = [np.argmax(o, axis=1) for o in res]
-      print " ".join([str(output[0]) for output in res])
+      res = [o for o in res[:len(ids)] if o > 0]
+      print " " + " ".join([data.to_symbol(output[0]) for output in res])
       sys.stdout.write("> ")
       sys.stdout.flush()
       inpt = sys.stdin.readline()