Kaynağa Gözat

Ability to train the translation model on arbitrary input sources.

Viacheslav Kovalevskyi 8 yıl önce
ebeveyn
işleme
c902a86785

+ 50 - 18
tutorials/rnn/translate/data_utils.py

@@ -239,8 +239,8 @@ def data_to_token_ids(data_path, target_path, vocabulary_path,
           counter += 1
           if counter % 100000 == 0:
             print("  tokenizing line %d" % counter)
-          token_ids = sentence_to_token_ids(tf.compat.as_bytes(line), vocab,
-                                            tokenizer, normalize_digits)
+          token_ids = sentence_to_token_ids(line, vocab, tokenizer,
+                                            normalize_digits)
           tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
 
 
@@ -267,24 +267,56 @@ def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, tokenizer
   train_path = get_wmt_enfr_train_set(data_dir)
   dev_path = get_wmt_enfr_dev_set(data_dir)
 
+  from_train_path = train_path + ".en"
+  to_train_path = train_path + ".fr"
+  from_dev_path = dev_path + ".en"
+  to_dev_path = dev_path + ".fr"
+  return prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, en_vocabulary_size,
+                      fr_vocabulary_size, tokenizer)
+
+
+def prepare_data(data_dir, from_train_path, to_train_path, from_dev_path, to_dev_path, from_vocabulary_size,
+                 to_vocabulary_size, tokenizer=None):
+  """Preapre all necessary files that are required for the training.
+
+    Args:
+      data_dir: directory in which the data sets will be stored.
+      from_train_path: path to the file that includes "from" training samples.
+      to_train_path: path to the file that includes "to" training samples.
+      from_dev_path: path to the file that includes "from" dev samples.
+      to_dev_path: path to the file that includes "to" dev samples.
+      from_vocabulary_size: size of the "from language" vocabulary to create and use.
+      to_vocabulary_size: size of the "to language" vocabulary to create and use.
+      tokenizer: a function to use to tokenize each data sentence;
+        if None, basic_tokenizer will be used.
+
+    Returns:
+      A tuple of 6 elements:
+        (1) path to the token-ids for "from language" training data-set,
+        (2) path to the token-ids for "to language" training data-set,
+        (3) path to the token-ids for "from language" development data-set,
+        (4) path to the token-ids for "to language" development data-set,
+        (5) path to the "from language" vocabulary file,
+        (6) path to the "to language" vocabulary file.
+    """
   # Create vocabularies of the appropriate sizes.
-  fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
-  en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
-  create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size, tokenizer)
-  create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size, tokenizer)
+  to_vocab_path = os.path.join(data_dir, "vocab%d" % to_vocabulary_size)
+  from_vocab_path = os.path.join(data_dir, "vocab%d" % from_vocabulary_size)
+  create_vocabulary(to_vocab_path, to_train_path , to_vocabulary_size, tokenizer)
+  create_vocabulary(from_vocab_path, from_train_path , from_vocabulary_size, tokenizer)
 
   # Create token ids for the training data.
-  fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
-  en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
-  data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path, tokenizer)
-  data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path, tokenizer)
+  to_train_ids_path = to_train_path + (".ids%d" % to_vocabulary_size)
+  from_train_ids_path = from_train_path + (".ids%d" % from_vocabulary_size)
+  data_to_token_ids(to_train_path, to_train_ids_path, to_vocab_path, tokenizer)
+  data_to_token_ids(from_train_path, from_train_ids_path, from_vocab_path, tokenizer)
 
   # Create token ids for the development data.
-  fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
-  en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
-  data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path, tokenizer)
-  data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path, tokenizer)
-
-  return (en_train_ids_path, fr_train_ids_path,
-          en_dev_ids_path, fr_dev_ids_path,
-          en_vocab_path, fr_vocab_path)
+  to_dev_ids_path = to_dev_path + (".ids%d" % to_vocabulary_size)
+  from_dev_ids_path = from_dev_path + (".ids%d" % from_vocabulary_size)
+  data_to_token_ids(to_dev_path, to_dev_ids_path, to_vocab_path, tokenizer)
+  data_to_token_ids(from_dev_path, from_dev_ids_path, from_vocab_path, tokenizer)
+
+  return (from_train_ids_path, to_train_ids_path,
+          from_dev_ids_path, to_dev_ids_path,
+          from_vocab_path, to_vocab_path)

+ 8 - 13
tutorials/rnn/translate/seq2seq_model.py

@@ -108,27 +108,22 @@ class Seq2SeqModel(object):
         local_b = tf.cast(b, tf.float32)
         local_inputs = tf.cast(inputs, tf.float32)
         return tf.cast(
-            tf.nn.sampled_softmax_loss(
-                weights=local_w_t,
-                biases=local_b,
-                labels=labels,
-                inputs=local_inputs,
-                num_sampled=num_samples,
-                num_classes=self.target_vocab_size),
+            tf.nn.sampled_softmax_loss(local_w_t, local_b, local_inputs, labels,
+                                       num_samples, self.target_vocab_size),
             dtype)
       softmax_loss_function = sampled_loss
 
     # Create the internal multi-layer cell for our RNN.
-    single_cell = tf.contrib.rnn.GRUCell(size)
+    single_cell = tf.nn.rnn_cell.GRUCell(size)
     if use_lstm:
-      single_cell = tf.contrib.rnn.BasicLSTMCell(size)
+      single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
     cell = single_cell
     if num_layers > 1:
-      cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
+      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
 
     # The seq2seq function: we use embedding for the input and attention.
     def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
-      return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
+      return tf.nn.seq2seq.embedding_attention_seq2seq(
           encoder_inputs,
           decoder_inputs,
           cell,
@@ -158,7 +153,7 @@ class Seq2SeqModel(object):
 
     # Training outputs and losses.
     if forward_only:
-      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
+      self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
           self.encoder_inputs, self.decoder_inputs, targets,
           self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
           softmax_loss_function=softmax_loss_function)
@@ -170,7 +165,7 @@ class Seq2SeqModel(object):
               for output in self.outputs[b]
           ]
     else:
-      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
+      self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
           self.encoder_inputs, self.decoder_inputs, targets,
           self.target_weights, buckets,
           lambda x, y: seq2seq_f(x, y, False),

+ 38 - 13
tutorials/rnn/translate/translate.py

@@ -55,10 +55,14 @@ tf.app.flags.DEFINE_integer("batch_size", 64,
                             "Batch size to use during training.")
 tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
 tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
-tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
-tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
+tf.app.flags.DEFINE_integer("from_vocab_size", 40000, "English vocabulary size.")
+tf.app.flags.DEFINE_integer("to_vocab_size", 40000, "French vocabulary size.")
 tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
 tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
+tf.app.flags.DEFINE_string("from_train_data", None, "Training data.")
+tf.app.flags.DEFINE_string("to_train_data", None, "Training data.")
+tf.app.flags.DEFINE_string("from_dev_data", None, "Training data.")
+tf.app.flags.DEFINE_string("to_dev_data", None, "Training data.")
 tf.app.flags.DEFINE_integer("max_train_data_size", 0,
                             "Limit on the size of training data (0: no limit).")
 tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
@@ -119,8 +123,8 @@ def create_model(session, forward_only):
   """Create translation model and initialize or load parameters in session."""
   dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
   model = seq2seq_model.Seq2SeqModel(
-      FLAGS.en_vocab_size,
-      FLAGS.fr_vocab_size,
+      FLAGS.from_vocab_size,
+      FLAGS.to_vocab_size,
       _buckets,
       FLAGS.size,
       FLAGS.num_layers,
@@ -142,10 +146,31 @@ def create_model(session, forward_only):
 
 def train():
   """Train a en->fr translation model using WMT data."""
-  # Prepare WMT data.
-  print("Preparing WMT data in %s" % FLAGS.data_dir)
-  en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
-      FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
+  from_train = None
+  to_train = None
+  from_dev = None
+  to_dev = None
+  if FLAGS.from_train_data and FLAGS.to_train_data:
+    from_train_data = FLAGS.from_train_data
+    to_train_data = FLAGS.to_train_data
+    from_dev_data = from_train_data
+    to_dev_data = to_train_data
+    if FLAGS.from_dev_data and FLAGS.to_dev_data:
+      from_dev_data = FLAGS.from_dev_data
+      to_dev_data = FLAGS.to_dev_data
+    from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
+        FLAGS.data_dir,
+        from_train_data,
+        to_train_data,
+        from_dev_data,
+        to_dev_data,
+        FLAGS.from_vocab_size,
+        FLAGS.to_vocab_size)
+  else:
+      # Prepare WMT data.
+      print("Preparing WMT data in %s" % FLAGS.data_dir)
+      from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
+          FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)
 
   with tf.Session() as sess:
     # Create model.
@@ -155,8 +180,8 @@ def train():
     # Read data into buckets and compute their sizes.
     print ("Reading development and training data (limit: %d)."
            % FLAGS.max_train_data_size)
-    dev_set = read_data(en_dev, fr_dev)
-    train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
+    dev_set = read_data(from_dev, to_dev)
+    train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
     train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
     train_total_size = float(sum(train_bucket_sizes))
 
@@ -225,9 +250,9 @@ def decode():
 
     # Load vocabularies.
     en_vocab_path = os.path.join(FLAGS.data_dir,
-                                 "vocab%d.en" % FLAGS.en_vocab_size)
+                                 "vocab%d.from" % FLAGS.from_vocab_size)
     fr_vocab_path = os.path.join(FLAGS.data_dir,
-                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
+                                 "vocab%d.to" % FLAGS.to_vocab_size)
     en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
     _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
 
@@ -245,7 +270,7 @@ def decode():
           bucket_id = i
           break
       else:
-        logging.warning("Sentence truncated: %s", sentence) 
+        logging.warning("Sentence truncated: %s", sentence)
 
       # Get a 1-element batch to feed the sentence to the model.
       encoder_inputs, decoder_inputs, target_weights = model.get_batch(