| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302 |
- # Copyright 2016 The TensorFlow Authors All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # pylint: disable=line-too-long
- r"""Training for Domain Separation Networks (DSNs).
- -- Compile:
- $ blaze build -c opt --copt=-mavx --config=cuda \
- third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train
- -- Run:
- $
- ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train
- \
- --similarity_loss=dann \
- --basic_tower=dsn_cropped_linemod \
- --source_dataset=pose_synthetic \
- --target_dataset=pose_real \
- --learning_rate=0.012 \
- --alpha_weight=0.26 \
- --gamma_weight=0.0115 \
- --weight_decay=4e-5 \
- --layers_to_regularize=fc3 \
- --use_separation \
- --alsologtostderr
- """
- # pylint: enable=line-too-long
- from __future__ import division
- import tensorflow as tf
- from domain_adaptation.datasets import dataset_factory
- import dsn
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

# ------------------------------------------------------------------------------
# Data flags.
# ------------------------------------------------------------------------------
tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')
tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')
tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')
# main() only mixes labeled target data into the source batches when this flag
# is something other than the literal string 'none'.
tf.app.flags.DEFINE_string(
    'target_labeled_dataset', 'none',
    'Labeled target dataset mixed into the source batches for '
    'semi-supervised training; "none" disables it.')
# NOTE(review): the default is a site-specific filesystem path; pass
# --dataset_dir explicitly when running anywhere else.
tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/',
                           'The directory where the dataset files are stored.')

# ------------------------------------------------------------------------------
# Training / infrastructure flags.
# ------------------------------------------------------------------------------
tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')
tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
                           'Directory where to write event logs.')
tf.app.flags.DEFINE_string(
    'layers_to_regularize', 'fc3',
    'Comma-separated list of layer names to use MMD regularization on.')
tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')

# Loss weights; forwarded to dsn.create_model through main()'s model_params.
tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
                          'The coefficient for scaling the reconstruction '
                          'loss.')
tf.app.flags.DEFINE_float(
    'beta_weight', 1e-6,
    'The coefficient for scaling the private/shared difference loss.')
tf.app.flags.DEFINE_float(
    'gamma_weight', 1e-6,
    'The coefficient for scaling the shared encoding similarity loss.')
tf.app.flags.DEFINE_float('pose_weight', 0.125,
                          'The coefficient for scaling the pose loss.')
tf.app.flags.DEFINE_float(
    'weight_decay', 1e-6,
    'The coefficient for the L2 regularization applied for all weights.')

tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 60,
    'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
    'save_interval_secs', 60,
    'The frequency with which the model is saved, in seconds.')
tf.app.flags.DEFINE_integer(
    'max_number_of_steps', None,
    'The maximum number of gradient steps. Use None to train indefinitely.')
tf.app.flags.DEFINE_integer(
    'domain_separation_startpoint', 1,
    'The global step to add the domain separation losses.')
# NOTE(review): not read in this file; presumably consumed via FLAGS by
# another module -- confirm before removing.
tf.app.flags.DEFINE_integer(
    'bipartite_assignment_top_k', 3,
    'The number of top-k matches to use in bipartite matching adaptation.')

# Learning-rate schedule: staircase exponential decay + momentum (see main()).
tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')
tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')
tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')

tf.app.flags.DEFINE_bool('use_separation', False,
                         'Use our domain separation model.')
tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
tf.app.flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')
tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
                            'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
    'task', 0,
    'The Task ID. This value is used when training with multiple workers to '
    'identify each worker.')
tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
                           'The decoder to use.')
tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
                           'The encoder to use.')
################################################################################
# Flags that control the architecture and losses
################################################################################
# All three values are forwarded to dsn.create_model by main():
# similarity_loss positionally, recon_loss_name via model_params, and
# basic_tower as the basic_tower_name keyword.
# NOTE(review): the module docstring's example passes --similarity_loss=dann,
# which is not among the values listed in the help text below -- confirm the
# accepted set against dsn.create_model.
tf.app.flags.DEFINE_string(
    'similarity_loss',
    'grl',
    'The method to use for encouraging the common encoder codes to be '
    'similar, one of "grl", "mmd", "corr".')
tf.app.flags.DEFINE_string(
    'recon_loss_name',
    'sum_of_pairwise_squares',
    'The name of the reconstruction loss.')
tf.app.flags.DEFINE_string(
    'basic_tower',
    'pose_mini',
    'The basic tower building block.')
def provide_batch_fn():
  """Returns the callable used to load every dataset batch in this script.

  Returns:
    `dataset_factory.provide_batch`, returned as-is (not called).
  """
  batch_provider = dataset_factory.provide_batch
  return batch_provider
def _mix_in_labeled_target(source_images, source_labels):
  """Mixes labeled target-domain samples into the source batch.

  Used for semi-supervised training (`--target_labeled_dataset` != 'none').
  Each example is kept from the source batch when a uniform random draw falls
  below the source dataset's share of the combined training-set size, and is
  otherwise replaced by the corresponding labeled target example. The mixed
  examples are then re-batched through `tf.train.shuffle_batch`.

  Args:
    source_images: Image batch returned by `provide_batch_fn()` for the source
      dataset.
    source_labels: Label dict returned alongside `source_images`. This function
      reads its 'classes' and 'num_train_samples' entries and, when present,
      'quaternions'.

  Returns:
    An `(images, labels, domain_selection_mask)` tuple. `labels` is a new dict
    holding 'classes' and, only if `source_labels` had them, 'quaternions'.
    `domain_selection_mask` is True for examples taken from the source domain.
  """
  # Load the labeled target data the same way as every other dataset here.
  target_semi_images, target_semi_labels = provide_batch_fn()(
      FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
      FLAGS.num_readers, FLAGS.batch_size, FLAGS.num_preprocessing_threads)

  # Fraction of source samples in the combined labeled pool, so each batch
  # mixes the two domains in proportion to their dataset sizes.
  proportion = float(source_labels['num_train_samples']) / (
      source_labels['num_train_samples'] +
      target_semi_labels['num_train_samples'])
  rnd_tensor = tf.random_uniform(
      (target_semi_images.get_shape().as_list()[0],))
  domain_selection_mask = rnd_tensor < proportion

  images = tf.where(domain_selection_mask, source_images, target_semi_images)
  classes = tf.where(domain_selection_mask, source_labels['classes'],
                     target_semi_labels['classes'])

  # Decided on the *input* dict. The returned label dict is rebuilt below, so
  # testing the rebuilt dict would always be False and drop the pose labels.
  has_quaternions = 'quaternions' in source_labels

  # Capacity and min_after_dequeue of the re-shuffling queue.
  capacity = 50000
  min_after_dequeue = 5000
  if has_quaternions:
    poses = tf.where(domain_selection_mask, source_labels['quaternions'],
                     target_semi_labels['quaternions'])
    images, classes, poses, domain_selection_mask = tf.train.shuffle_batch(
        [images, classes, poses, domain_selection_mask],
        FLAGS.batch_size,
        capacity,
        min_after_dequeue,
        num_threads=1,
        enqueue_many=True)
  else:
    images, classes, domain_selection_mask = tf.train.shuffle_batch(
        [images, classes, domain_selection_mask],
        FLAGS.batch_size,
        capacity,
        min_after_dequeue,
        num_threads=1,
        enqueue_many=True)

  labels = {'classes': classes}
  if has_quaternions:
    labels['quaternions'] = poses
  return images, labels, domain_selection_mask


def main(_):
  """Builds the DSN training graph from FLAGS and runs slim training.

  Loads source and target batches, optionally mixes labeled target data into
  the source batch (semi-supervised mode), builds the model via
  `dsn.create_model`, and trains with momentum SGD on a staircase
  exponentially-decayed learning rate.
  """
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }
  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      if FLAGS.target_labeled_dataset != 'none':
        # Semi-supervised: labeled target samples join the "source" batch.
        source_images, source_labels, domain_selection_mask = (
            _mix_in_labeled_target(source_images, source_labels))
      else:
        # Unsupervised: every labeled sample comes from the source domain.
        domain_selection_mask = tf.fill(
            (source_images.get_shape().as_list()[0],), True)

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)

      # Configure the optimization scheme.
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')
      tf.summary.scalar('learning_rate', learning_rate)

      # Build the total loss once, then reuse it for both the summary and the
      # train op.
      total_loss = tf.losses.get_total_loss()
      tf.summary.scalar('total_loss', total_loss)

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)

      # Run training.
      train_op = slim.learning.create_train_op(
          total_loss,
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
      slim.learning.train(
          train_op=train_op,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
  # tf.app.run parses the command-line flags, then calls main(argv).
  tf.app.run(main=main)
|