123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- # Copyright 2016 The TensorFlow Authors All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Functions to create a DSN model and add the different losses to it.
- Specifically, in this file we define the:
- - Shared Encoding Similarity Loss Module, with:
- - The MMD Similarity method
- - The Correlation Similarity method
- - The Gradient Reversal (Domain-Adversarial) method
- - Difference Loss Module
- - Reconstruction Loss Module
- - Task Loss Module
- """
- from functools import partial
- import tensorflow as tf
- import losses
- import models
- import utils
- slim = tf.contrib.slim
- ################################################################################
- # HELPER FUNCTIONS
- ################################################################################
def dsn_loss_coefficient(params):
  """The global_step-dependent weight that specifies when to kick in DSN losses.

  Args:
    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'

  Returns:
    A weight to that effectively enables or disables the DSN-related losses,
    i.e. similarity, difference, and reconstruction losses.
  """
  # Before the startpoint the coefficient is numerically zero (1e-10) so the
  # DSN losses contribute nothing; afterwards it switches to 1.0.
  step = slim.get_or_create_global_step()
  startpoint = params['domain_separation_startpoint']
  before_startpoint = tf.less(step, startpoint)
  return tf.where(before_startpoint, 1e-10, 1.0)
- ################################################################################
- # MODEL CREATION
- ################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height width, channels].
    target_labels: a dictionary with the name, tensor pairs.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  # Look up the shared-encoder tower function by name on the models module.
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # Run the target images through the same tower. reuse=True shares the
  # variables already created under the 'towers' scope by add_task_loss.
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:  # We have pose estimation as well.
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  # The activations of this layer are what the similarity and difference
  # losses regularize.
  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    # Add the private encoders, the shared decoder, and the difference and
    # reconstruction losses.
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)
def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    ValueError: if `method_name` is not recognized.
  """
  # Resolve the similarity function by name; getattr raises if unknown.
  similarity_fn = getattr(losses, method_name)
  # Ramp the loss in only after the domain-separation startpoint.
  loss_weight = params['gamma_weight'] * dsn_loss_coefficient(params)
  similarity_fn(source_samples, target_samples, loss_weight, scope)
def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  # Dispatch table from loss name to its tf.contrib.losses implementation.
  recon_loss_fns = {
      'sum_of_pairwise_squares': tf.contrib.losses.mean_pairwise_squared_error,
      'sum_of_squares': tf.contrib.losses.mean_squared_error,
  }
  try:
    loss_fn = recon_loss_fns[recon_loss_name]
  except KeyError:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  # The contrib loss functions register the weighted loss in the losses
  # collection as a side effect.
  loss = loss_fn(recons, images, weight)

  # Gate the summary on a finiteness check of the loss value.
  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'
  """

  def normalize_images(images):
    # Rescale to [0, 1]; used only for the image summaries below.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Combine the shared and private codes by elementwise summation.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders: one per domain, each with its own (non-shared) variables.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  decoder_fn = getattr(models, decoder_name)

  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  # Source Auto-encoding.
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    # Reconstructions from only one of the two codes, for visualization.
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  # Target Auto-encoding; the decoder variables are shared across domains.
  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. NOTE: use list comprehensions rather than map() here —
  # on Python 3 map() returns a lazy iterator, which tf.concat rejects
  # (it expects a list of tensors).
  source_reconstructions = tf.concat(
      [normalize_images(t) for t in [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ]], 2)
  target_reconstructions = tf.concat(
      [normalize_images(t) for t in [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ]], 2)
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  if source_reconstructions.get_shape().as_list()[3] == 4:
    # RGB-D input: also visualize the reconstructed depth channel.
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)
def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: labels from the source domain, a tensor of size [batch_size].
      or a tuple of (quaternions, class_labels)
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  # Creates the tower variables under 'towers'; create_model later reuses
  # this scope for the target tower.
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    # NOTE(review): the finiteness assertion only gates ops created inside
    # this context (the histogram summary); the loss added to the collection
    # below is not forced through the check — confirm this is intended.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  # Classification loss against the one-hot 'classes' labels; added to the
  # losses collection by tf.losses as a side effect.
  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints
|