# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Functions to create a DSN model and add the different losses to it.
  16. Specifically, in this file we define the:
  17. - Shared Encoding Similarity Loss Module, with:
  18. - The MMD Similarity method
  19. - The Correlation Similarity method
  20. - The Gradient Reversal (Domain-Adversarial) method
  21. - Difference Loss Module
  22. - Reconstruction Loss Module
  23. - Task Loss Module
  24. """
from functools import partial

import tensorflow as tf

import losses
import models
import utils

slim = tf.contrib.slim
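
# Illustrative only: a hypothetical `params` dictionary with the keys that the
# functions below expect (collected from their docstrings). The values and the
# encoder/decoder names are assumptions for this sketch, not the settings used
# by the original training setup.
_EXAMPLE_PARAMS = {
    'weight_decay': 1e-5,
    'layers_to_regularize': 'fc3',          # endpoint name of the shared code
    'use_separation': True,
    'domain_separation_startpoint': 10000,  # step at which DSN losses kick in
    'alpha_weight': 0.01,                   # reconstruction loss weight
    'beta_weight': 0.05,                    # difference loss weight
    'gamma_weight': 0.25,                   # similarity loss weight
    'pose_weight': 0.25,                    # pose (quaternion) loss weight
    'recon_loss_name': 'sum_of_pairwise_squares',
    'decoder_name': 'small_decoder',        # assumed name in models.py
    'encoder_name': 'default_encoder',      # assumed name in models.py
}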


################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """The global_step-dependent weight that specifies when to kick in DSN losses.

  Args:
    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'.

  Returns:
    A weight that effectively enables or disables the DSN-related losses,
    i.e. similarity, difference, and reconstruction losses.
  """
  return tf.where(
      tf.less(slim.get_or_create_global_step(),
              params['domain_separation_startpoint']), 1e-10, 1.0)
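
# For example (illustrative): with params={'domain_separation_startpoint': 10000},
# the coefficient above evaluates to 1e-10 while global_step < 10000 and to 1.0
# afterwards, so the similarity/difference/reconstruction losses are effectively
# switched off during an initial warm-up period for the task network.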


################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_labels: a dictionary of {name: tensor} pairs. 'classes' is one-hot
      for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_labels: a dictionary of {name: tensor} pairs.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'.
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
  tf.summary.scalar('eval/Target accuracy', target_accuracy)
  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include those target samples when we
  # compute the similarity loss, so we mask them out here.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params)
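
# Illustrative call sketch (assumptions: a hypothetical 'dann_mnist' tower in
# models.py, 32x32 RGB inputs, ten classes, and the _EXAMPLE_PARAMS dictionary
# above; the real pipeline feeds tensors from dataset providers rather than
# placeholders):
#
#   source_images = tf.placeholder(tf.float32, [32, 32, 32, 3])
#   target_images = tf.placeholder(tf.float32, [32, 32, 32, 3])
#   source_labels = {'classes': tf.one_hot(tf.zeros([32], tf.int64), 10)}
#   target_labels = {'classes': tf.one_hot(tf.zeros([32], tf.int64), 10)}
#   mask = tf.fill([32], True)
#   create_model(source_images, source_labels, mask, target_images,
#                target_labels, 'mmd_loss', _EXAMPLE_PARAMS, 'dann_mnist')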


def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    ValueError: if `method_name` is not recognized.
  """
  weight = dsn_loss_coefficient(params) * params['gamma_weight']
  method = getattr(losses, method_name)
  method(source_samples, target_samples, weight, scope)
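
# Note: `method_name` is resolved with getattr() against the losses module, so
# any function added to losses.py with the signature
# fn(source_samples, target_samples, weight, scope) can be selected by name.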


def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  if recon_loss_name == 'sum_of_pairwise_squares':
    loss_fn = tf.contrib.losses.mean_pairwise_squared_error
  elif recon_loss_name == 'sum_of_squares':
    loss_fn = tf.contrib.losses.mean_squared_error
  else:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  loss = loss_fn(recons, images, weight)

  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)


def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_shared: a tensor with first dimension batch_size.
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_shared: a tensor with first dimension batch_size.
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'.
  """

  def normalize_images(images):
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    return shared_repr + private_repr
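  # Note: despite its name, concat_operation combines the shared and private
  # codes by elementwise addition rather than by concatenation.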

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders, one per domain.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  # Private codes for each domain, taken from the endpoint named by
  # concat_layer.
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')
  # Add summaries.
  source_reconstructions = tf.concat(
      map(normalize_images, [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ]), 2)
  target_reconstructions = tf.concat(
      map(normalize_images, [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ]), 2)
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)


def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels].
    source_labels: a dictionary of labels from the source domain. Expects a
      one-hot 'classes' entry and, for pose estimation, a 'quaternions' entry.
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well.
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
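
# Note (illustrative): the quaternion, classification, and reconstruction
# losses above are registered with TensorFlow's standard losses collection
# (slim.losses.add_loss, tf.losses.softmax_cross_entropy, tf.contrib.losses),
# and the similarity/difference helpers in losses.py are presumably expected
# to do the same, so a training script would typically combine everything with
# tf.losses.get_total_loss() before building its train op.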