dsn_train.py

# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Training for Domain Separation Networks (DSNs).

-- Compile:
$ blaze build -c opt --copt=-mavx --config=cuda \
    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train

-- Run:
$ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train \
    --similarity_loss=dann \
    --basic_tower=dsn_cropped_linemod \
    --source_dataset=pose_synthetic \
    --target_dataset=pose_real \
    --learning_rate=0.012 \
    --alpha_weight=0.26 \
    --gamma_weight=0.0115 \
    --weight_decay=4e-5 \
    --layers_to_regularize=fc3 \
    --use_separation \
    --alsologtostderr
"""
# pylint: enable=line-too-long
from __future__ import division

import tensorflow as tf

from domain_adaptation.datasets import dataset_factory
import dsn

slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 32,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
                           'Source dataset to train on.')

tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
                           'Target dataset to train on.')

tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
                           'Labeled target dataset for the semi-supervised '
                           'setting, or "none" for unsupervised adaptation.')

tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/',
                           'The directory where the dataset files are stored.')

tf.app.flags.DEFINE_string('master', '',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
                           'Directory where to write event logs.')

tf.app.flags.DEFINE_string(
    'layers_to_regularize', 'fc3',
    'Comma-separated list of layer names to use MMD regularization on.')

tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate.')

tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
                          'The coefficient for scaling the reconstruction '
                          'loss.')

tf.app.flags.DEFINE_float(
    'beta_weight', 1e-6,
    'The coefficient for scaling the private/shared difference loss.')

tf.app.flags.DEFINE_float(
    'gamma_weight', 1e-6,
    'The coefficient for scaling the shared encoding similarity loss.')

tf.app.flags.DEFINE_float('pose_weight', 0.125,
                          'The coefficient for scaling the pose loss.')

tf.app.flags.DEFINE_float(
    'weight_decay', 1e-6,
    'The coefficient for the L2 regularization applied to all weights.')
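
# For reference, these weights correspond to the DSN training objective from
# Bousmalis et al., "Domain Separation Networks" (NIPS 2016):
#   L = L_task + alpha * L_recon + beta * L_difference + gamma * L_similarity
# where L_task is the supervised classification/pose loss on labeled data.
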
tf.app.flags.DEFINE_integer(
    'save_summaries_secs', 60,
    'The frequency with which summaries are saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'save_interval_secs', 60,
    'The frequency with which the model is saved, in seconds.')

tf.app.flags.DEFINE_integer(
    'max_number_of_steps', None,
    'The maximum number of gradient steps. Use None to train indefinitely.')

tf.app.flags.DEFINE_integer(
    'domain_separation_startpoint', 1,
    'The global step at which to add the domain separation losses.')

tf.app.flags.DEFINE_integer(
    'bipartite_assignment_top_k', 3,
    'The number of top-k matches to use in bipartite matching adaptation.')

tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')

tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')

tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')

tf.app.flags.DEFINE_bool('use_separation', False,
                         'Use our domain separation model.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')

tf.app.flags.DEFINE_integer(
    'ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then the parameters '
    'are handled locally by the worker.')

tf.app.flags.DEFINE_integer(
    'num_readers', 4,
    'The number of parallel readers that read data from the dataset.')

tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
                            'The number of threads used to create the batches.')

tf.app.flags.DEFINE_integer(
    'task', 0,
    'The task ID. This value is used when training with multiple workers to '
    'identify each worker.')

tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
                           'The decoder to use.')

tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
                           'The encoder to use.')

################################################################################
# Flags that control the architecture and losses.
################################################################################
tf.app.flags.DEFINE_string(
    'similarity_loss', 'grl',
    'The method to use for encouraging the common encoder codes to be '
    'similar, one of "grl", "mmd", "corr".')
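
# For context: "grl" trains a domain classifier through a gradient reversal
# layer (DANN-style adversarial similarity), "mmd" penalizes the maximum mean
# discrepancy between the shared source and target encodings, and "corr"
# penalizes differences between their correlation statistics.
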
tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
                           'The name of the reconstruction loss.')

tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
                           'The basic tower building block.')
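

# provide_batch_fn is a small indirection point: wrapping the provider in a
# function makes it possible to swap in a different input pipeline (e.g. in
# tests) without touching the training logic in main().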
def provide_batch_fn():
  """The provide_batch function to use."""
  return dataset_factory.provide_batch


def main(_):
  model_params = {
      'use_separation': FLAGS.use_separation,
      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
      'layers_to_regularize': FLAGS.layers_to_regularize,
      'alpha_weight': FLAGS.alpha_weight,
      'beta_weight': FLAGS.beta_weight,
      'gamma_weight': FLAGS.gamma_weight,
      'pose_weight': FLAGS.pose_weight,
      'recon_loss_name': FLAGS.recon_loss_name,
      'decoder_name': FLAGS.decoder_name,
      'encoder_name': FLAGS.encoder_name,
      'weight_decay': FLAGS.weight_decay,
      'batch_size': FLAGS.batch_size,
      'use_logging': FLAGS.use_logging,
      'ps_tasks': FLAGS.ps_tasks,
      'task': FLAGS.task,
  }

  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Load the data.
      source_images, source_labels = provide_batch_fn()(
          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
      target_images, target_labels = provide_batch_fn()(
          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
          FLAGS.batch_size, FLAGS.num_preprocessing_threads)

      # In the unsupervised case all the samples in the labeled
      # domain are from the source domain.
      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
                                      True)

      # When using the semi-supervised model we include labeled target data in
      # the source labeled data.
      if FLAGS.target_labeled_dataset != 'none':
        # 1000 is the maximum number of labeled target samples that exist in
        # the datasets.
        target_semi_images, target_semi_labels = provide_batch_fn()(
            FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
            FLAGS.num_readers, FLAGS.batch_size,
            FLAGS.num_preprocessing_threads)

        # Calculate the proportion of source domain samples in the semi-
        # supervised setting, so that the proportion is set accordingly in the
        # batches.
        proportion = float(source_labels['num_train_samples']) / (
            source_labels['num_train_samples'] +
            target_semi_labels['num_train_samples'])

        rnd_tensor = tf.random_uniform(
            (target_semi_images.get_shape().as_list()[0],))
        domain_selection_mask = rnd_tensor < proportion
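
        # For illustration with hypothetical sizes: if the source split had
        # 9000 labeled samples and the labeled target split 1000, proportion
        # would be 0.9, so ~90% of each mixed batch stays source-domain.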
        source_images = tf.where(domain_selection_mask, source_images,
                                 target_semi_images)
        source_class_labels = tf.where(domain_selection_mask,
                                       source_labels['classes'],
                                       target_semi_labels['classes'])

        # Record whether pose labels exist before source_labels is rebuilt
        # below, then re-shuffle the mixed source/target examples.
        has_pose_labels = 'quaternions' in source_labels
        if has_pose_labels:
          source_pose_labels = tf.where(domain_selection_mask,
                                        source_labels['quaternions'],
                                        target_semi_labels['quaternions'])
          (source_images, source_class_labels, source_pose_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [
                   source_images, source_class_labels, source_pose_labels,
                   domain_selection_mask
               ],
               FLAGS.batch_size,
               capacity=50000,
               min_after_dequeue=5000,
               num_threads=1,
               enqueue_many=True)
        else:
          (source_images, source_class_labels,
           domain_selection_mask) = tf.train.shuffle_batch(
               [source_images, source_class_labels, domain_selection_mask],
               FLAGS.batch_size,
               capacity=50000,
               min_after_dequeue=5000,
               num_threads=1,
               enqueue_many=True)
        source_labels = {}
        source_labels['classes'] = source_class_labels
        if has_pose_labels:
          source_labels['quaternions'] = source_pose_labels

      slim.get_or_create_global_step()
      tf.summary.image('source_images', source_images, max_outputs=3)
      tf.summary.image('target_images', target_images, max_outputs=3)

      dsn.create_model(
          source_images,
          source_labels,
          domain_selection_mask,
          target_images,
          target_labels,
          FLAGS.similarity_loss,
          model_params,
          basic_tower_name=FLAGS.basic_tower)
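
      # dsn.create_model is expected to register its task, reconstruction,
      # difference, and similarity losses in the standard tf.losses
      # collection; the total-loss summary and train op below read from it.
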
      # Configure the optimization scheme:
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          FLAGS.decay_steps,
          FLAGS.decay_rate,
          staircase=True,
          name='learning_rate')
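      # With staircase=True this evaluates to
      #   lr = learning_rate * decay_rate ** floor(global_step / decay_steps),
      # i.e. the rate is multiplied by 0.95 every 20000 steps by default.
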
      tf.summary.scalar('learning_rate', learning_rate)
      tf.summary.scalar('total_loss', tf.losses.get_total_loss())

      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      tf.logging.set_verbosity(tf.logging.INFO)

      # Run training.
      loss_tensor = slim.learning.create_train_op(
          tf.losses.get_total_loss(),
          opt,
          summarize_gradients=True,
          colocate_gradients_with_ops=True)
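
      # Each evaluation of loss_tensor computes the total loss, applies one
      # gradient step, and returns the loss value; slim.learning.train runs
      # it in a managed loop with checkpointing and summary writing.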
      slim.learning.train(
          train_op=loss_tensor,
          logdir=FLAGS.train_log_dir,
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.max_number_of_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)


if __name__ == '__main__':
  tf.app.run()