@@ -55,10 +55,14 @@ tf.app.flags.DEFINE_integer("batch_size", 64,
                             "Batch size to use during training.")
 tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
 tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
-tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
-tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
+tf.app.flags.DEFINE_integer("from_vocab_size", 40000, "English vocabulary size.")
+tf.app.flags.DEFINE_integer("to_vocab_size", 40000, "French vocabulary size.")
 tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
 tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
+tf.app.flags.DEFINE_string("from_train_data", None, "Training data.")
+tf.app.flags.DEFINE_string("to_train_data", None, "Training data.")
+tf.app.flags.DEFINE_string("from_dev_data", None, "Training data.")
+tf.app.flags.DEFINE_string("to_dev_data", None, "Training data.")
 tf.app.flags.DEFINE_integer("max_train_data_size", 0,
                             "Limit on the size of training data (0: no limit).")
 tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
@@ -119,8 +123,8 @@ def create_model(session, forward_only):
   """Create translation model and initialize or load parameters in session."""
   dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
   model = seq2seq_model.Seq2SeqModel(
-      FLAGS.en_vocab_size,
-      FLAGS.fr_vocab_size,
+      FLAGS.from_vocab_size,
+      FLAGS.to_vocab_size,
       _buckets,
       FLAGS.size,
       FLAGS.num_layers,
@@ -142,10 +146,31 @@ def create_model(session, forward_only):
 
 def train():
   """Train a en->fr translation model using WMT data."""
-  # Prepare WMT data.
-  print("Preparing WMT data in %s" % FLAGS.data_dir)
-  en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
-      FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
+  from_train = None
+  to_train = None
+  from_dev = None
+  to_dev = None
+  if FLAGS.from_train_data and FLAGS.to_train_data:
+    from_train_data = FLAGS.from_train_data
+    to_train_data = FLAGS.to_train_data
+    from_dev_data = from_train_data
+    to_dev_data = to_train_data
+    if FLAGS.from_dev_data and FLAGS.to_dev_data:
+      from_dev_data = FLAGS.from_dev_data
+      to_dev_data = FLAGS.to_dev_data
+    from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
+        FLAGS.data_dir,
+        from_train_data,
+        to_train_data,
+        from_dev_data,
+        to_dev_data,
+        FLAGS.from_vocab_size,
+        FLAGS.to_vocab_size)
+  else:
+    # Prepare WMT data.
+    print("Preparing WMT data in %s" % FLAGS.data_dir)
+    from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
+        FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)
 
   with tf.Session() as sess:
     # Create model.
@@ -155,8 +180,8 @@ def train():
     # Read data into buckets and compute their sizes.
     print ("Reading development and training data (limit: %d)."
            % FLAGS.max_train_data_size)
-    dev_set = read_data(en_dev, fr_dev)
-    train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
+    dev_set = read_data(from_dev, to_dev)
+    train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
     train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
     train_total_size = float(sum(train_bucket_sizes))
 
@@ -225,9 +250,9 @@ def decode():
 
     # Load vocabularies.
     en_vocab_path = os.path.join(FLAGS.data_dir,
-                                 "vocab%d.en" % FLAGS.en_vocab_size)
+                                 "vocab%d.from" % FLAGS.from_vocab_size)
     fr_vocab_path = os.path.join(FLAGS.data_dir,
-                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
+                                 "vocab%d.to" % FLAGS.to_vocab_size)
     en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
     _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
 
@@ -245,7 +270,7 @@ def decode():
           bucket_id = i
           break
       else:
-        logging.warning("Sentence truncated: %s", sentence) 
+        logging.warning("Sentence truncated: %s", sentence)
 
       # Get a 1-element batch to feed the sentence to the model.
       encoder_inputs, decoder_inputs, target_weights = model.get_batch(
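
Note on usage: with the flags added above, the model is no longer tied to the bundled WMT English-French corpus. A hypothetical invocation (assuming the script is still run as translate.py; the corpus paths below are placeholders, not part of this patch) might look like:

  python translate.py \
      --data_dir=/tmp --train_dir=/tmp \
      --from_train_data=/path/to/train.src --to_train_data=/path/to/train.tgt \
      --from_dev_data=/path/to/dev.src --to_dev_data=/path/to/dev.tgt \
      --from_vocab_size=40000 --to_vocab_size=40000

As the new code in train() shows, the dev flags are optional (the training files are reused as the dev set when they are omitted), and when --from_train_data/--to_train_data are not given the else branch keeps the original behavior of preparing the WMT data in --data_dir.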