Browse Source

DSN infrastructure staging

Konstantinos Bousmalis 8 years ago
parent
commit
89c7c98711
35 changed files with 4361 additions and 0 deletions
  1. 4 0
      domain_adaptation/OWNERS
  2. 0 0
      domain_adaptation/WORKSPACE
  3. 0 0
      domain_adaptation/__init__.py
  4. 43 0
      domain_adaptation/datasets/BUILD
  5. 0 0
      domain_adaptation/datasets/__init__.py
  6. 106 0
      domain_adaptation/datasets/dataset_factory.py
  7. 243 0
      domain_adaptation/datasets/download_and_convert_mnist_m.py
  8. 97 0
      domain_adaptation/datasets/mnist_m.py
  9. 165 0
      domain_adaptation/domain_separation/#models_test.py#
  10. 1 0
      domain_adaptation/domain_separation/.#models_test.py
  11. 157 0
      domain_adaptation/domain_separation/.pipertmp-2H2v0i-dsn_eval.py
  12. 152 0
      domain_adaptation/domain_separation/.pipertmp-9mVtwS-dsn_eval.py
  13. 157 0
      domain_adaptation/domain_separation/.pipertmp-Ckvhfy-dsn_eval.py
  14. 214 0
      domain_adaptation/domain_separation/.pipertmp-OiMpXz-dsn_eval.py
  15. 152 0
      domain_adaptation/domain_separation/.pipertmp-WMYPqp-dsn_eval.py
  16. 229 0
      domain_adaptation/domain_separation/.pipertmp-son4h0-dsn_eval.py
  17. 185 0
      domain_adaptation/domain_separation/BUILD
  18. 41 0
      domain_adaptation/domain_separation/README.md
  19. 0 0
      domain_adaptation/domain_separation/__init__.py
  20. BIN
      domain_adaptation/domain_separation/_grl_ops.so
  21. 353 0
      domain_adaptation/domain_separation/dsn.py
  22. 175 0
      domain_adaptation/domain_separation/dsn_eval.py
  23. 157 0
      domain_adaptation/domain_separation/dsn_test.py
  24. 301 0
      domain_adaptation/domain_separation/dsn_train.py
  25. 34 0
      domain_adaptation/domain_separation/grl_op_grads.py
  26. 47 0
      domain_adaptation/domain_separation/grl_op_kernels.cc
  27. 16 0
      domain_adaptation/domain_separation/grl_op_shapes.py
  28. 36 0
      domain_adaptation/domain_separation/grl_ops.cc
  29. 28 0
      domain_adaptation/domain_separation/grl_ops.py
  30. 73 0
      domain_adaptation/domain_separation/grl_ops_test.py
  31. 292 0
      domain_adaptation/domain_separation/losses.py
  32. 110 0
      domain_adaptation/domain_separation/losses_test.py
  33. 443 0
      domain_adaptation/domain_separation/models.py
  34. 167 0
      domain_adaptation/domain_separation/models_test.py
  35. 183 0
      domain_adaptation/domain_separation/utils.py

+ 4 - 0
domain_adaptation/OWNERS

@@ -0,0 +1,4 @@
+konstantinos
+nsilberman
+dilipkay
+dumitru

+ 0 - 0
domain_adaptation/WORKSPACE


+ 0 - 0
domain_adaptation/__init__.py


+ 43 - 0
domain_adaptation/datasets/BUILD

@@ -0,0 +1,43 @@
+# Domain Adaptation Scenarios Datasets
+
+package(
+    default_visibility = [
+        ":internal",
+    ],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//domain_adaptation/...",
+    ],
+)
+
+py_library(
+    name = "dataset_factory",
+    srcs = ["dataset_factory.py"],
+    deps = [
+        ":mnist_m",
+        "//slim:mnist",
+    ],
+)
+
+py_binary(
+    name = "download_and_convert_mnist_m",
+    srcs = ["download_and_convert_mnist_m.py"],
+    deps = [
+        "//slim:dataset_utils",
+    ],
+)
+
+py_binary(
+    name = "mnist_m",
+    srcs = ["mnist_m.py"],
+    deps = [
+        "//slim:dataset_utils",
+    ],
+)

+ 0 - 0
domain_adaptation/datasets/__init__.py


+ 106 - 0
domain_adaptation/datasets/dataset_factory.py

@@ -0,0 +1,106 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A factory-pattern class which returns image/label pairs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import mnist_m
+from slim.datasets import mnist
+
+slim = tf.contrib.slim
+
+
+def get_dataset(dataset_name,
+                split_name,
+                dataset_dir,
+                file_pattern=None,
+                reader=None):
+  """Given a dataset name and a split_name returns a Dataset.
+
+  Args:
+    dataset_name: String, the name of the dataset.
+    split_name: A train/test split name.
+    dataset_dir: The directory where the dataset files are stored.
+    file_pattern: The file pattern to use for matching the dataset source files.
+    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
+      reader defined by each dataset is used.
+
+  Returns:
+    A tf-slim `Dataset` class.
+
+  Raises:
+    ValueError: if `dataset_name` isn't recognized.
+  """
+  dataset_name_to_module = {'mnist': mnist, 'mnist_m': mnist_m}
+  if dataset_name not in dataset_name_to_module:
+    raise ValueError('Unknown dataset name: %s.' % dataset_name)
+
+  return dataset_name_to_module[dataset_name].get_split(split_name, dataset_dir,
+                                                        file_pattern, reader)
+
+
+def provide_batch(dataset_name, split_name, dataset_dir, num_readers,
+                  batch_size, num_preprocessing_threads):
+  """Provides a batch of images and corresponding labels.
+
+  Args:
+    dataset_name: String, the name of the dataset.
+    split_name: A train/test split name.
+    dataset_dir: The directory where the dataset files are stored.
+    num_readers: The number of readers used by DatasetDataProvider.
+    batch_size: The size of the batch requested.
+    num_preprocessing_threads: The number of preprocessing threads for
+      tf.train.batch.
+
+  Returns:
+    A batch of
+      images: tensor of [batch_size, height, width, channels].
+      labels: dictionary of labels.
+  """
+  dataset = get_dataset(dataset_name, split_name, dataset_dir)
+  provider = slim.dataset_data_provider.DatasetDataProvider(
+      dataset,
+      num_readers=num_readers,
+      common_queue_capacity=20 * batch_size,
+      common_queue_min=10 * batch_size)
+  [image, label] = provider.get(['image', 'label'])
+
+  # Convert images to float32
+  image = tf.image.convert_image_dtype(image, tf.float32)
+  image -= 0.5
+  image *= 2
+
+  # Load the data.
+  labels = {}
+  images, labels['classes'] = tf.train.batch(
+      [image, label],
+      batch_size=batch_size,
+      num_threads=num_preprocessing_threads,
+      capacity=5 * batch_size)
+  labels['classes'] = slim.one_hot_encoding(labels['classes'],
+                                            dataset.num_classes)
+
+  # Convert mnist to RGB and 32x32 so that it can match mnist_m.
+  if dataset_name == 'mnist':
+    images = tf.image.grayscale_to_rgb(images)
+    images = tf.image.resize_images(images, [32, 32])
+  return images, labels
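
For reference, a minimal usage sketch of the factory above (a sketch only: it assumes the MNIST-M TFRecords have already been generated, and the /tmp/mnist_m path and hyperparameter values are illustrative, not part of the commit):

  import tensorflow as tf
  from domain_adaptation.datasets import dataset_factory

  images, labels = dataset_factory.provide_batch(
      'mnist_m', 'train', '/tmp/mnist_m',
      num_readers=4, batch_size=32, num_preprocessing_threads=4)

  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image_batch, class_batch = sess.run([images, labels['classes']])
    # image_batch: float32 in [-1, 1], shape (32, 32, 32, 3)
    # class_batch: one-hot labels, shape (32, 10)
    coord.request_stop()
    coord.join(threads)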

+ 243 - 0
domain_adaptation/datasets/download_and_convert_mnist_m.py

@@ -0,0 +1,243 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
+
+This module downloads the MNIST-M data, uncompresses it, reads the files
+that make up the MNIST-M data and creates two TFRecord datasets: one for train
+and one for test. Each TFRecord dataset is comprised of a set of TF-Example
+protocol buffers, each of which contain a single image and label.
+
+The script should take about a minute to run.
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+import sys
+
+import google3
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+from google3.third_party.tensorflow_models.slim.datasets import dataset_utils
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir', None,
+    'The directory where the output TFRecords and temporary files are saved.')
+
+FLAGS = tf.app.flags.FLAGS
+
+# The URL where the MNIST-M data can be downloaded.
+_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
+_TRAIN_DATA_DIR = 'mnist_m_train'
+_TRAIN_LABELS_FILENAME = 'mnist_m_train_labels'
+_TEST_DATA_DIR = 'mnist_m_test'
+_TEST_LABELS_FILENAME = 'mnist_m_test_labels'
+
+_IMAGE_SIZE = 32
+_NUM_CHANNELS = 3
+
+# The number of images in the training set.
+_NUM_TRAIN_SAMPLES = 59001
+
+# The number of images to be kept from the training set for the validation set.
+_NUM_VALIDATION = 1000
+
+# The number of images in the test set.
+_NUM_TEST_SAMPLES = 9001
+
+# Seed for repeatability.
+_RANDOM_SEED = 0
+
+# The names of the classes.
+_CLASS_NAMES = [
+    'zero',
+    'one',
+    'two',
+    'three',
+    'four',
+    'five',
+    'six',
+    'seven',
+    'eight',
+    'nine',
+]
+
+
+class ImageReader(object):
+  """Helper class that provides TensorFlow image coding utilities."""
+
+  def __init__(self):
+    # Initializes function that decodes RGB PNG data.
+    self._decode_png_data = tf.placeholder(dtype=tf.string)
+    self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
+
+  def read_image_dims(self, sess, image_data):
+    image = self.decode_png(sess, image_data)
+    return image.shape[0], image.shape[1]
+
+  def decode_png(self, sess, image_data):
+    image = sess.run(
+        self._decode_png, feed_dict={self._decode_png_data: image_data})
+    assert len(image.shape) == 3
+    assert image.shape[2] == 3
+    return image
+
+
+def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
+  """Converts the given filenames to a TFRecord dataset.
+
+  Args:
+    split_name: The name of the dataset, either 'train' or 'valid'.
+    filenames: A list of absolute paths to png images.
+    filename_to_class_id: A dictionary from filenames (strings) to class ids
+      (integers).
+    dataset_dir: The directory where the converted datasets are stored.
+  """
+  print('Converting the {} split.'.format(split_name))
+  # Train and validation splits are both in the train directory.
+  if split_name in ['train', 'valid']:
+    png_directory = os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train')
+  elif split_name == 'test':
+    png_directory = os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test')
+
+  with tf.Graph().as_default():
+    image_reader = ImageReader()
+
+    with tf.Session('') as sess:
+      output_filename = _get_output_filename(dataset_dir, split_name)
+
+      with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+        for filename in filenames:
+          # Read the filename:
+          image_data = tf.gfile.FastGFile(
+              os.path.join(png_directory, filename), 'rb').read()
+          height, width = image_reader.read_image_dims(sess, image_data)
+
+          class_id = filename_to_class_id[filename]
+          example = dataset_utils.image_to_tfexample(image_data, 'png', height,
+                                                     width, class_id)
+          tfrecord_writer.write(example.SerializeToString())
+
+  sys.stdout.write('\n')
+  sys.stdout.flush()
+
+
+def _extract_labels(label_filename):
+  """Extract the labels into a dict of filenames to int labels.
+
+  Args:
+    labels_filename: The filename of the MNIST-M labels.
+
+  Returns:
+    A dictionary of filenames to int labels.
+  """
+  print('Extracting labels from: ', label_filename)
+  label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
+  label_lines = [line.rstrip('\n').split() for line in label_file]
+  labels = {}
+  for line in label_lines:
+    assert len(line) == 2
+    labels[line[0]] = int(line[1])
+  return labels
+
+
+def _get_output_filename(dataset_dir, split_name):
+  """Creates the output filename.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+    split_name: The name of the train/test split.
+
+  Returns:
+    An absolute file path.
+  """
+  return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
+
+
+def _get_filenames(dataset_dir):
+  """Returns a list of filenames and inferred class names.
+
+  Args:
+    dataset_dir: A directory containing a set PNG encoded MNIST-M images.
+
+  Returns:
+    A list of image file paths, relative to `dataset_dir`.
+  """
+  photo_filenames = []
+  for filename in os.listdir(dataset_dir):
+    photo_filenames.append(filename)
+  return photo_filenames
+
+
+def run(dataset_dir):
+  """Runs the download and conversion operation.
+
+  Args:
+    dataset_dir: The dataset directory where the dataset is stored.
+  """
+  if not tf.gfile.Exists(dataset_dir):
+    tf.gfile.MakeDirs(dataset_dir)
+
+  train_filename = _get_output_filename(dataset_dir, 'train')
+  testing_filename = _get_output_filename(dataset_dir, 'test')
+
+  if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
+    print('Dataset files already exist. Exiting without re-creating them.')
+    return
+
+  # TODO(konstantinos): Add download and cleanup functionality.
+
+  train_validation_filenames = _get_filenames(
+      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train'))
+  test_filenames = _get_filenames(
+      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test'))
+
+  # Divide into train and validation:
+  random.seed(_RANDOM_SEED)
+  random.shuffle(train_validation_filenames)
+  train_filenames = train_validation_filenames[_NUM_VALIDATION:]
+  validation_filenames = train_validation_filenames[:_NUM_VALIDATION]
+
+  train_validation_filenames_to_class_ids = _extract_labels(
+      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train_labels.txt'))
+  test_filenames_to_class_ids = _extract_labels(
+      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test_labels.txt'))
+
+  # Convert the train, validation, and test sets.
+  _convert_dataset('train', train_filenames,
+                   train_validation_filenames_to_class_ids, dataset_dir)
+  _convert_dataset('valid', validation_filenames,
+                   train_validation_filenames_to_class_ids, dataset_dir)
+  _convert_dataset('test', test_filenames, test_filenames_to_class_ids,
+                   dataset_dir)
+
+  # Finally, write the labels file:
+  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
+  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
+
+  print('\nFinished converting the MNIST-M dataset!')
+
+
+def main(_):
+  run(FLAGS.dataset_dir)
+
+
+if __name__ == '__main__':
+  tf.app.run()
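
A minimal sketch of driving the converter (assuming, per the TODO in run(), that the MNIST-M archive has already been unpacked by hand so that <dataset_dir>/mnist-m/mnist_m_train, <dataset_dir>/mnist-m/mnist_m_test, and the two label files exist; the /tmp/mnist_m path is illustrative):

  # Via the flag interface:
  #   python download_and_convert_mnist_m.py --dataset_dir=/tmp/mnist_m
  # Or directly from Python:
  from domain_adaptation.datasets import download_and_convert_mnist_m

  download_and_convert_mnist_m.run('/tmp/mnist_m')
  # Writes /tmp/mnist_m/mnist_m_{train,valid,test}.tfrecord, plus a label
  # file mapping class ids to class names.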

+ 97 - 0
domain_adaptation/datasets/mnist_m.py

@@ -0,0 +1,97 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides data for the MNIST-M dataset.
+
+The dataset scripts used to create the dataset can be found at:
+tensorflow_models/domain_adaptation/datasets/download_and_convert_mnist_m.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+slim = tf.contrib.slim
+
+_FILE_PATTERN = 'mnist_m_%s.tfrecord'
+
+_SPLITS_TO_SIZES = {'train': 58001, 'valid': 1000, 'test': 9001}
+
+_NUM_CLASSES = 10
+
+_ITEMS_TO_DESCRIPTIONS = {
+    'image': 'A [32 x 32 x 3] RGB image.',
+    'label': 'A single integer between 0 and 9',
+}
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+  """Gets a dataset tuple with instructions for reading MNIST.
+
+  Args:
+    split_name: A train/test split name.
+    dataset_dir: The base directory of the dataset sources.
+
+  Returns:
+    A `Dataset` namedtuple.
+
+  Raises:
+    ValueError: if `split_name` is not a valid train/test split.
+  """
+  if split_name not in _SPLITS_TO_SIZES:
+    raise ValueError('split name %s was not recognized.' % split_name)
+
+  if not file_pattern:
+    file_pattern = _FILE_PATTERN
+  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+  # Allowing None in the signature so that dataset_factory can use the default.
+  if reader is None:
+    reader = tf.TFRecordReader
+
+  keys_to_features = {
+      'image/encoded':
+          tf.FixedLenFeature((), tf.string, default_value=''),
+      'image/format':
+          tf.FixedLenFeature((), tf.string, default_value='png'),
+      'image/class/label':
+          tf.FixedLenFeature(
+              [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
+  }
+
+  items_to_handlers = {
+      'image': slim.tfexample_decoder.Image(shape=[32, 32, 3], channels=3),
+      'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
+  }
+
+  decoder = slim.tfexample_decoder.TFExampleDecoder(
+      keys_to_features, items_to_handlers)
+
+  labels_to_names = None
+  if dataset_utils.has_labels(dataset_dir):
+    labels_to_names = dataset_utils.read_label_file(dataset_dir)
+
+  return slim.dataset.Dataset(
+      data_sources=file_pattern,
+      reader=reader,
+      decoder=decoder,
+      num_samples=_SPLITS_TO_SIZES[split_name],
+      num_classes=_NUM_CLASSES,
+      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+      labels_to_names=labels_to_names)
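
A minimal sketch of reading one decoded example through this split definition (assuming the TFRecords produced by the converter above live under the illustrative /tmp/mnist_m path):

  import tensorflow as tf
  from domain_adaptation.datasets import mnist_m

  slim = tf.contrib.slim

  dataset = mnist_m.get_split('valid', '/tmp/mnist_m')
  provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
  image, label = provider.get(['image', 'label'])
  # image: a uint8 tensor of shape [32, 32, 3]; label: an int64 in [0, 9].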

+ 165 - 0
domain_adaptation/domain_separation/#models_test.py#

@@ -0,0 +1,165 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN components."""
+
+import numpy as np
+import tensorflow as tf
+
+import models
+
+class SharedEncodersTest(tf.test.TestCase):
+
+  def _testSharedEncoder(self,
+                         input_shape=[5, 28, 28, 1],
+                         model=models.dann_mnist,
+                         is_training=True):
+    images = tf.to_float(np.random.rand(*input_shape))
+
+    with self.test_session() as sess:
+      logits, _ = model(images)
+      sess.run(tf.global_variables_initializer())
+      logits_np = sess.run(logits)
+    return logits_np
+
+  def testBuildGRLMnistModel(self):
+    logits = self._testSharedEncoder(model=getattr(models,
+                                                   'dann_mnist'))
+    self.assertEqual(logits.shape, (5, 10))
+    self.assertTrue(np.any(logits))
+
+  def testBuildGRLSvhnModel(self):
+    logits = self._testSharedEncoder(model=getattr(models,
+                                                   'dann_svhn'))
+    self.assertEqual(logits.shape, (5, 10))
+    self.assertTrue(np.any(logits))
+
+  def testBuildGRLGtsrbModel(self):
+    logits = self._testSharedEncoder([5, 40, 40, 3],
+                                     getattr(models, 'dann_gtsrb'))
+    self.assertEqual(logits.shape, (5, 43))
+    self.assertTrue(np.any(logits))
+
+  def testBuildPoseModel(self):
+    logits = self._testSharedEncoder([5, 64, 64, 4],
+                                     getattr(models, 'dsn_cropped_linemod'))
+    self.assertEqual(logits.shape, (5, 11))
+    self.assertTrue(np.any(logits))
+
+  def testBuildPoseModelWithBatchNorm(self):
+    images = tf.to_float(np.random.rand(10, 64, 64, 4))
+
+    with self.test_session() as sess:
+      logits, _ = getattr(models, 'dsn_cropped_linemod')(
+          images, batch_norm_params=models.default_batch_norm_params(True))
+      sess.run(tf.global_variables_initializer())
+      logits_np = sess.run(logits)
+    self.assertEqual(logits_np.shape, (10, 11))
+    self.assertTrue(np.any(logits_np))
+
+
+class EncoderTest(tf.test.TestCase):
+
+  def _testEncoder(self, batch_norm_params=None, channels=1):
+    images = tf.to_float(np.random.rand(10, 28, 28, channels))
+
+    with self.test_session() as sess:
+      end_points = models.default_encoder(
+          images, 128, batch_norm_params=batch_norm_params)
+      sess.run(tf.global_variables_initializer())
+      private_code = sess.run(end_points['fc3'])
+    self.assertEqual(private_code.shape, (10, 128))
+    self.assertTrue(np.any(private_code))
+    self.assertTrue(np.all(np.isfinite(private_code)))
+
+  def testEncoder(self):
+    self._testEncoder()
+
+  def testEncoderMultiChannel(self):
+    self._testEncoder(None, 4)
+
+  def testEncoderIsTrainingBatchNorm(self):
+    self._testEncoder(models.default_batch_norm_params(True))
+
+  def testEncoderBatchNorm(self):
+    self._testEncoder(models.default_batch_norm_params(False))
+
+
+class DecoderTest(tf.test.TestCase):
+
+  def _testDecoder(self,
+                   height=64,
+                   width=64,
+                   channels=4,
+                   batch_norm_params=None,
+                   decoder=models.small_decoder):
+    codes = tf.to_float(np.random.rand(32, 100))
+
+    with self.test_session() as sess:
+      output = decoder(
+          codes,
+          height=height,
+          width=width,
+          channels=channels,
+          batch_norm_params=batch_norm_params)
+      sess.run(tf.global_variables_initializer())
+      output_np = sess.run(output)
+    self.assertEqual(output_np.shape, (32, height, width, channels))
+    self.assertTrue(np.any(output_np))
+    self.assertTrue(np.all(np.isfinite(output_np)))
+
+  def testSmallDecoder(self):
+    self._testDecoder(28, 28, 4, None, getattr(models, 'small_decoder'))
+
+  def testSmallDecoderThreeChannels(self):
+    self._testDecoder(28, 28, 3)
+
+  def testSmallDecoderBatchNorm(self):
+    self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))
+
+  def testSmallDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))
+
+  def testLargeDecoder(self):
+    self._testDecoder(32, 32, 4, None, getattr(models, 'large_decoder'))
+
+  def testLargeDecoderThreeChannels(self):
+    self._testDecoder(32, 32, 3, None, getattr(models, 'large_decoder'))
+
+  def testLargeDecoderBatchNorm(self):
+    self._testDecoder(32, 32, 4,
+                      models.default_batch_norm_params(False),
+                      getattr(models, 'large_decoder'))
+
+  def testLargeDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(32, 32, 4,
+                      models.default_batch_norm_params(True),
+                      getattr(models, 'large_decoder'))
+
+  def testGtsrbDecoder(self):
+    self._testDecoder(40, 40, 3, None, getattr(models, 'gtsrb_decoder'))
+
+  def testGtsrbDecoderBatchNorm(self):
+    self._testDecoder(40, 40, 4,
+                      models.default_batch_norm_params(False),
+                      getattr(models, 'gtsrb_decoder'))
+
+  def testGtsrbDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(40, 40, 4,
+                      models.default_batch_norm_params(True),
+                      getattr(models, 'gtsrb_decoder'))
+
+
+if __name__ == '__main__':
+  tf.test.main()

+ 1 - 0
domain_adaptation/domain_separation/.#models_test.py

@@ -0,0 +1 @@
+konstantinos@kalivaki.lon.corp.google.com.139121:1490035651

+ 157 - 0
domain_adaptation/domain_separation/.pipertmp-2H2v0i-dsn_eval.py

@@ -0,0 +1,157 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+      --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+from google3.robotics.cad_learning.domain_adaptation.fnist import data_provider
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 50,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', 'local',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string(
+    'dataset', 'pose_real',
+    'Which dataset to test on: "pose_real", "pose_synthetic".')
+tf.app.flags.DEFINE_string('portion', 'valid',
+                           'Which portion to test on: "valid", "test".')
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
+                           'The basic tower building block.')
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  return tf.contrib.metrics.streaming_mean(logcost)
+
+
+def to_degrees(predictions, labels):
+  """Converts a log quaternion distance to an angle.
+
+  Args:
+    log_quaternion_loss: The log quaternion distance between two
+      unit quaternions (or a batch of pairs of quaternions).
+
+  Returns:
+    The angle in degrees of the implied angle-axis representation.
+  """
+  product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+  log_quaternion_loss = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  angle_loss = tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
+  return tf.contrib.metrics.streaming_mean(angle_loss)
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    images, labels = data_provider.provide(FLAGS.dataset, FLAGS.portion,
+                                           FLAGS.batch_size)
+
+    num_classes = labels['classes'].shape[1]
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = models.provide(FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images, is_training=False, num_classes=num_classes)
+    names_to_values = {}
+    names_to_updates = {}
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      metric_name = 'Angle Mean Error'
+      names_to_values[metric_name], names_to_updates[metric_name] = to_degrees(
+          labels['quaternions'], endpoints['quaternion_pred'])
+
+      metric_name = 'Log Quaternion Error'
+      names_to_values[metric_name], names_to_updates[
+          metric_name] = quaternion_metric(labels['quaternions'],
+                                           endpoints['quaternion_pred'])
+      metric_name = 'Accuracy'
+      names_to_values[metric_name], names_to_updates[
+          metric_name] = tf.contrib.metrics.streaming_accuracy(
+              tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    metric_name = 'Accuracy'
+    names_to_values[metric_name], names_to_updates[
+        metric_name] = tf.contrib.metrics.streaming_accuracy(
+            tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.contrib.deprecated.scalar_summary(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.contrib.deprecated.merge_summary(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()
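
Up to the 1e-4 stabilizer inside the log, quaternion_metric and to_degrees above reduce to the standard geodesic distance between unit quaternions: the rotation angle separating q1 and q2 is 2 * arccos(|<q1, q2>|). A quick numpy check of that identity (illustrative values only, not part of the commit):

  import numpy as np

  def rotation_angle_deg(q1, q2):
    # q1, q2: [batch, 4] arrays of unit quaternions.
    dots = np.abs(np.sum(q1 * q2, axis=1))
    return 2.0 * np.degrees(np.arccos(np.clip(dots, -1.0, 1.0)))

  identity = np.array([[1.0, 0.0, 0.0, 0.0]])  # no rotation
  quarter_x = np.array([[np.cos(np.pi / 4), np.sin(np.pi / 4), 0.0, 0.0]])
  print(rotation_angle_deg(identity, quarter_x))  # [90.]: a 90-degree turn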

+ 152 - 0
domain_adaptation/domain_separation/.pipertmp-9mVtwS-dsn_eval.py

@@ -0,0 +1,152 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+      --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+from google3.robotics.cad_learning.domain_adaptation.fnist import data_provider
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', 'local',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string(
+    'dataset', 'pose_real',
+    'Which dataset to test on: "pose_real", "pose_synthetic".')
+tf.app.flags.DEFINE_string('portion', 'valid',
+                           'Which portion to test on: "valid", "test".')
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'dsn_cropped_linemod',
+                           'The basic tower building block.')
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+  return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+  angles = 2 * (
+      180.0 /
+      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+  return angles
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    images, labels = data_provider.provide(FLAGS.dataset, FLAGS.portion,
+                                           FLAGS.batch_size)
+
+    num_classes = labels['classes'].get_shape().as_list()[1]
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = getattr(models, FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images,
+          num_classes=num_classes,
+          is_training=False,
+          batch_norm_params=None)
+    metric_names_to_values = {}
+
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      angle_errors, = tf.py_func(
+          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+          [tf.float32])
+
+      metric_names_to_values[
+          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+      metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+    accuracy = tf.contrib.metrics.streaming_accuracy(
+        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.summary.scalar(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 157 - 0
domain_adaptation/domain_separation/.pipertmp-Ckvhfy-dsn_eval.py

@@ -0,0 +1,157 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+      --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+from google3.robotics.cad_learning.domain_adaptation.fnist import data_provider
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 50,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', 'local',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string(
+    'dataset', 'pose_real',
+    'Which dataset to test on: "pose_real", "pose_synthetic".')
+tf.app.flags.DEFINE_string('portion', 'valid',
+                           'Which portion to test on: "valid", "test".')
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
+                           'The basic tower building block.')
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  return tf.contrib.metrics.streaming_mean(logcost)
+
+
+def to_degrees(predictions, labels):
+  """Converts a log quaternion distance to an angle.
+
+  Args:
+    log_quaternion_loss: The log quaternion distance between two
+      unit quaternions (or a batch of pairs of quaternions).
+
+  Returns:
+    The angle in degrees of the implied angle-axis representation.
+  """
+  product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+  log_quaternion_loss = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  angle_loss = tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
+  return tf.contrib.metrics.streaming_mean(angle_loss)
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    images, labels = data_provider.provide(FLAGS.dataset, FLAGS.portion,
+                                           FLAGS.batch_size)
+
+    num_classes = labels['classes'].shape[1]
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = models.provide(FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images, is_training=False, num_classes=num_classes)
+    names_to_values = {}
+    names_to_updates = {}
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      metric_name = 'Angle Mean Error'
+      names_to_values[metric_name], names_to_updates[metric_name] = to_degrees(
+          labels['quaternions'], endpoints['quaternion_pred'])
+
+      metric_name = 'Log Quaternion Error'
+      names_to_values[metric_name], names_to_updates[
+          metric_name] = quaternion_metric(labels['quaternions'],
+                                           endpoints['quaternion_pred'])
+      metric_name = 'Accuracy'
+      names_to_values[metric_name], names_to_updates[
+          metric_name] = tf.contrib.metrics.streaming_accuracy(
+              tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    metric_name = 'Accuracy'
+    names_to_values[metric_name], names_to_updates[
+        metric_name] = tf.contrib.metrics.streaming_accuracy(
+            tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.contrib.deprecated.scalar_summary(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.contrib.deprecated.merge_summary(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 214 - 0
domain_adaptation/domain_separation/.pipertmp-OiMpXz-dsn_eval.py

@@ -0,0 +1,214 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+      --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+
+from google3.third_party.tensorflow_models.domain_adaptation.datasets import dataset_factory
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', '',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+                           'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('dataset', 'mnist_m',
+                           'Which dataset to test on: "mnist", "mnist_m".')
+
+tf.app.flags.DEFINE_string('split', 'valid',
+                           'Which portion to test on: "valid", "test".')
+
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+>>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
+tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
+==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
+tf.app.flags.DEFINE_string('basic_tower', 'dsn_cropped_linemod',
+==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
+tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
+<<<<
+                           'The basic tower building block.')
+>>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
+==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
+
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+
+<<<<
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+  return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+  angles = 2 * (
+      180.0 /
+      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+  return angles
+
+
+>>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
+  Returns:
+    The angle in degrees of the implied angle-axis representation.
+  """
+  product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+  log_quaternion_loss = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  angle_loss = tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
+  return tf.contrib.metrics.streaming_mean(angle_loss)
+
+
+==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
+==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
+def provide_batch_fn():
+  """ The provide_batch function to use. """
+  return dataset_factory.provide_batch
+
+
+<<<<
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    # Load the data.
+    images, labels = provide_batch_fn()(
+        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
+
+    num_classes = labels['classes'].get_shape().as_list()[1]
+
+    tf.summary.image('eval_images', images, max_outputs=3)
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = getattr(models, FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images,
+          num_classes=num_classes,
+          is_training=False,
+          batch_norm_params=None)
+    metric_names_to_values = {}
+
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      angle_errors, = tf.py_func(
+          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+          [tf.float32])
+
+      metric_names_to_values[
+          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+      metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+    accuracy = tf.contrib.metrics.streaming_accuracy(
+        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+>>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
+==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+    for i in xrange(num_classes):
+      index_map = tf.one_hot(i, depth=num_classes)
+      name = 'PR/Precision_{}'.format(i)
+      metric_names_to_values[name] = slim.metrics.streaming_precision(
+          tf.gather(index_map, predictions), tf.gather(index_map, labels))
+      name = 'PR/Recall_{}'.format(i)
+      metric_names_to_values[name] = slim.metrics.streaming_recall(
+          tf.gather(index_map, predictions), tf.gather(index_map, labels))
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+<<<<
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.summary.scalar(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()
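
The per-class precision/recall block in the "YOURS" hunk above hinges on one trick: index_map = tf.one_hot(i, depth=num_classes) is a 0/1 vector with a 1 only at class i, so gathering it at the predicted (or true) integer ids yields a binary "is class i" indicator that the streaming precision/recall metrics can consume as a one-vs-all problem. A numpy rendering of the same trick (illustrative values only):

  import numpy as np

  predictions = np.array([0, 2, 1, 2, 2])
  labels = np.array([0, 1, 1, 2, 2])
  num_classes = 3

  for i in range(num_classes):
    index_map = np.eye(num_classes)[i]  # one-hot row for class i
    pred_is_i = index_map[predictions]  # 1.0 where prediction == i
    label_is_i = index_map[labels]      # 1.0 where label == i
    true_pos = np.sum(pred_is_i * label_is_i)
    precision = true_pos / max(np.sum(pred_is_i), 1.0)
    recall = true_pos / max(np.sum(label_is_i), 1.0)
    print('class %d: precision %.2f, recall %.2f' % (i, precision, recall))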

+ 152 - 0
domain_adaptation/domain_separation/.pipertmp-WMYPqp-dsn_eval.py

@@ -0,0 +1,152 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+      --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+from google3.robotics.cad_learning.domain_adaptation.fnist import data_provider
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', 'local',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string(
+    'dataset', 'pose_real',
+    'Which dataset to test on: "pose_real", "pose_synthetic".')
+tf.app.flags.DEFINE_string('portion', 'valid',
+                           'Which portion to test on: "valid", "test".')
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'dsn_cropped_linemod',
+                           'The basic tower building block.')
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+  return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+  angles = 2 * (
+      180.0 /
+      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+  return angles
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    images, labels = data_provider.provide(FLAGS.dataset, FLAGS.portion,
+                                           FLAGS.batch_size)
+
+    num_classes = labels['classes'].get_shape().as_list()[1]
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = getattr(models, FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images,
+          num_classes=num_classes,
+          is_training=False,
+          batch_norm_params=None)
+    metric_names_to_values = {}
+
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      angle_errors, = tf.py_func(
+          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+          [tf.float32])
+
+      metric_names_to_values[
+          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+      metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+    accuracy = tf.contrib.metrics.streaming_accuracy(
+        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.summary.scalar(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()
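
The evaluation loop above is built on TF-Slim's streaming metrics, where each
metric is a (value, update) pair of ops: running the update ops folds one batch
at a time into accumulator variables, and the value ops read out the aggregate.
A minimal, self-contained sketch of that idiom (TF 1.x; the values are
illustrative):

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

values = tf.placeholder(tf.float32, shape=[2])
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Mean value': slim.metrics.streaming_mean(values),
})

with tf.Session() as sess:
  # Streaming metrics keep their accumulators in local variables.
  sess.run(tf.local_variables_initializer())
  for batch in [np.array([0.1, 0.4], np.float32),
                np.array([0.2, 0.3], np.float32)]:
    # Each run of the update ops folds one batch into the running totals.
    sess.run(list(names_to_updates.values()), feed_dict={values: batch})
  # The value ops return the final aggregates; here the mean is 0.25.
  print(sess.run(names_to_values))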

+ 229 - 0
domain_adaptation/domain_separation/.pipertmp-son4h0-dsn_eval.py

@@ -0,0 +1,229 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+$ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+    --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import google3
+
+import numpy as np
+import tensorflow as tf
+
+from google3.third_party.tensorflow_models.domain_adaptation.datasets import dataset_factory
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
+from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', '',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+                           'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('dataset', 'mnist_m',
+                           'Which dataset to test on: "mnist", "mnist_m".')
+
+tf.app.flags.DEFINE_string('split', 'valid',
+                           'Which portion to test on: "valid", "test".')
+
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
+                           'The basic tower building block.')
+
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+  return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+  angles = 2 * (
+      180.0 /
+      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+  return angles
+
+
+def provide_batch_fn():
+  """The provide_batch function to use."""
+  return dataset_factory.provide_batch
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    # Load the data.
+    images, labels = provide_batch_fn()(
+        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
+
+    num_classes = labels['classes'].get_shape().as_list()[1]
+
+    tf.summary.image('eval_images', images, max_outputs=3)
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = getattr(models, FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images,
+          num_classes=num_classes,
+          is_training=False,
+          batch_norm_params=None)
+    metric_names_to_values = {}
+
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      angle_errors, = tf.py_func(
+          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+          [tf.float32])
+
+      metric_names_to_values[
+          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+      metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+    accuracy = tf.contrib.metrics.streaming_accuracy(
+        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+
+    if FLAGS.enable_precision_recall:
+      for i in xrange(num_classes):
+        index_map = tf.one_hot(i, depth=num_classes)
+        name = 'PR/Precision_{}'.format(i)
+        metric_names_to_values[name] = slim.metrics.streaming_precision(
+            tf.gather(index_map, predictions), tf.gather(index_map, labels))
+        name = 'PR/Recall_{}'.format(i)
+        metric_names_to_values[name] = slim.metrics.streaming_recall(
+            tf.gather(index_map, predictions), tf.gather(index_map, labels))
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.summary.scalar(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()
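
The angular error in this file is computed in NumPy and spliced into the graph
with tf.py_func, which requires the function's outputs to match the declared
dtypes. The pattern in isolation (the quaternions below are hypothetical; TF
1.x):

import numpy as np
import tensorflow as tf

def angle_diff(true_q, pred_q):
  # Angle in degrees between unit quaternions, via their absolute dot product.
  return (2 * (180.0 / np.pi) * np.arccos(
      np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))).astype(np.float32)

true_q = tf.constant([[1.0, 0.0, 0.0, 0.0]])
pred_q = tf.constant([[0.7071, 0.7071, 0.0, 0.0]])  # 90 degrees about x.
angles, = tf.py_func(angle_diff, [true_q, pred_q], [tf.float32])

with tf.Session() as sess:
  print(sess.run(angles))  # ~[90.]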

+ 185 - 0
domain_adaptation/domain_separation/BUILD

@@ -0,0 +1,185 @@
+# Domain Separation Networks
+
+package(
+    default_visibility = [
+        ":internal",
+    ],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//domain_adaptation/...",
+    ],
+)
+
+py_library(
+    name = "models",
+    srcs = [
+        "models.py",
+    ],
+    deps = [
+        ":utils",
+    ],
+)
+
+py_library(
+    name = "losses",
+    srcs = [
+        "losses.py",
+    ],
+    deps = [
+        ":grl_op_grads_py",
+        # ":grl_op_kernels",
+        ":grl_op_shapes_py",
+        ":grl_ops",
+        # ":grl_ops_py",
+        ":utils",
+    ],
+)
+
+py_test(
+    name = "losses_test",
+    srcs = [
+        "losses_test.py",
+    ],
+    deps = [
+        ":losses",
+        ":utils",
+    ],
+)
+
+py_library(
+    name = "dsn",
+    srcs = [
+        "dsn.py",
+    ],
+    deps = [
+        ":grl_op_grads_py",
+        #":grl_op_kernels",
+        ":grl_op_shapes_py",
+        ":grl_ops",
+        #":grl_ops_py",
+        ":losses",
+        ":models",
+        ":utils",
+    ],
+)
+
+py_test(
+    name = "dsn_test",
+    srcs = [
+        "dsn_test.py",
+    ],
+    deps = [
+        ":dsn",
+    ],
+)
+
+py_binary(
+    name = "dsn_train",
+    srcs = [
+        "dsn_train.py",
+    ],
+    deps = [
+        ":dsn",
+        ":models",
+        "//domain_adaptation/datasets:dataset_factory",
+    ],
+)
+
+py_binary(
+    name = "dsn_eval",
+    srcs = [
+        "dsn_eval.py",
+    ],
+    deps = [
+        ":dsn",
+        ":models",
+        "//domain_adaptation/datasets:dataset_factory",
+    ],
+)
+
+py_test(
+    name = "models_test",
+    srcs = [
+        "models_test.py",
+    ],
+    deps = [
+        ":models",
+        "//domain_adaptation/datasets:dataset_factory",
+    ],
+)
+
+py_library(
+    name = "utils",
+    srcs = [
+        "utils.py",
+    ],
+    deps = [
+    ],
+)
+
+py_library(
+    name = "grl_op_grads_py",
+    srcs = [
+        "grl_op_grads.py",
+    ],
+    deps = [
+        ":grl_ops",
+    ],
+)
+
+py_library(
+    name = "grl_op_shapes_py",
+    srcs = [
+        "grl_op_shapes.py",
+    ],
+    deps = [
+    ],
+)
+
+py_library(
+    name = "grl_ops",
+    srcs = ["grl_ops.py"],
+    data = ["_grl_ops.so"],
+)
+#cc_library(
+#    name = "grl_ops",
+#    srcs = ["grl_ops.cc"],
+#    deps = ["//tensorflow/core:framework"],
+#   alwayslink = 1,
+#)
+
+#tf_gen_op_wrapper_py(
+#    name = "grl_ops_py",
+#    out = "grl_ops.py",
+#    deps = [":grl_ops"],
+#)
+
+#cc_library(
+#    name = "grl_op_kernels",
+#    srcs = ["grl_op_kernels.cc"],
+#    deps = [
+#        "//tensorflow/core:framework",
+#        "//tensorflow/core:protos_all",
+#    ],
+#    alwayslink = 1,
+#)
+
+py_test(
+    name = "grl_ops_test",
+    size = "small",
+    srcs = ["grl_ops_test.py"],
+    deps = [
+        ":grl_op_grads_py",
+        #   ":grl_op_kernels",
+        ":grl_op_shapes_py",
+        ":grl_ops",
+        #":grl_ops_py",
+    ],
+)
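
The commented-out cc_library and tf_gen_op_wrapper_py targets above indicate
that the GradientReversal kernel ships here as the prebuilt _grl_ops.so listed
in the grl_ops data attribute, rather than being compiled by Bazel. A custom op
packaged this way is typically loaded at runtime along these lines (a sketch;
the path handling is illustrative, and TensorFlow derives the Python symbol
from the registered op name):

import os.path
import tensorflow as tf

# Load the shared object that registers the GradientReversal op and kernel.
_grl_module = tf.load_op_library(
    os.path.join(os.path.dirname(__file__), '_grl_ops.so'))

# The op is exposed under a snake_case version of its registered name.
gradient_reversal = _grl_module.gradient_reversal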

+ 41 - 0
domain_adaptation/domain_separation/README.md

@@ -0,0 +1,41 @@
+# Domain Separation Networks
+
+## Introduction
+This is the code for the "Domain Separation Networks" paper
+by K. Bousmalis, G. Trigeorgis, et al., which was presented at NIPS 2016. The
+paper can be found here: https://arxiv.org/abs/1608.06019
+
+## Contact
+This code was open-sourced by Konstantinos Bousmalis (konstantinos@google.com, github:bousmalis)
+
+## Installation
+You will need to have the following installed on your machine before trying out the DSN code.
+
+*  Tensorflow: https://www.tensorflow.org/install/
+*  Bazel: https://bazel.build/
+
+## Running the code for adapting MNIST to MNIST-M
+To run the MNIST to MNIST-M experiments with DANNs and/or DANNs with domain
+separation (DSNs), first set the directory where you downloaded MNIST and
+MNIST-M:
+$ export DSN_DATA_DIR=/your/dir
+
+Then you need to build the binaries with Bazel:
+
+$ bazel build -c opt domain_adaptation/domain_separation/...
+
+You can then train with the following command:
+
+$ ./bazel-bin/domain_adaptation/domain_separation/dsn_train \
+      --similarity_loss=dann_loss  \
+      --basic_tower=dann_mnist  \
+      --source_dataset=mnist  \
+      --target_dataset=mnist_m  \
+      --learning_rate=0.0117249  \
+      --gamma_weight=0.251175  \
+      --weight_decay=1e-6  \
+      --layers_to_regularize=fc3  \
+      --nouse_separation  \
+      --master=""  \
+      --dataset_dir=${DSN_DATA_DIR}  \
+      -v --use_logging
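+
+Once training has produced checkpoints you can evaluate them with the
+dsn_eval binary, using only flags that dsn_eval.py defines (the values below
+are illustrative):
+
+$ ./bazel-bin/domain_adaptation/domain_separation/dsn_eval \
+      --dataset=mnist_m \
+      --split=test \
+      --basic_tower=dann_mnist \
+      --dataset_dir=${DSN_DATA_DIR} \
+      --checkpoint_dir=/tmp/da/ \
+      --eval_dir=/tmp/da/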

+ 0 - 0
domain_adaptation/domain_separation/__init__.py


Binary
domain_adaptation/domain_separation/_grl_ops.so


+ 353 - 0
domain_adaptation/domain_separation/dsn.py

@@ -0,0 +1,353 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to create a DSN model and add the different losses to it.
+
+Specifically, in this file we define the:
+  - Shared Encoding Similarity Loss Module, with:
+    - The MMD Similarity method
+    - The Correlation Similarity method
+    - The Gradient Reversal (Domain-Adversarial) method
+  - Difference Loss Module
+  - Reconstruction Loss Module
+  - Task Loss Module
+"""
+from functools import partial
+
+import tensorflow as tf
+
+import losses
+import models
+import utils
+
+slim = tf.contrib.slim
+
+
+################################################################################
+# HELPER FUNCTIONS
+################################################################################
+def dsn_loss_coefficient(params):
+  """The global_step-dependent weight that specifies when to kick in DSN losses.
+
+  Args:
+    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'
+
+  Returns:
+    A weight that effectively enables or disables the DSN-related losses,
+    i.e. similarity, difference, and reconstruction losses.
+  """
+  return tf.where(
+      tf.less(slim.get_or_create_global_step(),
+              params['domain_separation_startpoint']), 1e-10, 1.0)
+
+
+################################################################################
+# MODEL CREATION
+################################################################################
+def create_model(source_images, source_labels, domain_selection_mask,
+                 target_images, target_labels, similarity_loss, params,
+                 basic_tower_name):
+  """Creates a DSN model.
+
+  Args:
+    source_images: images from the source domain, a tensor of size
+      [batch_size, height, width, channels]
+    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
+      hot for the number of classes.
+    domain_selection_mask: a boolean tensor of size [batch_size, ] that denotes
+      which of the labeled images belong to the source domain.
+    target_images: images from the target domain, a tensor of size
+      [batch_size, height, width, channels].
+    target_labels: a dictionary with the name, tensor pairs.
+    similarity_loss: The type of method to use for encouraging
+      the codes from the shared encoder to be similar.
+    params: A dictionary of parameters. Expecting 'weight_decay',
+      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
+      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
+      'decoder_name', 'encoder_name'
+    basic_tower_name: the name of the tower to use for the shared encoder.
+
+  Raises:
+    ValueError: if the arch is not one of the available architectures.
+  """
+  network = getattr(models, basic_tower_name)
+  num_classes = source_labels['classes'].get_shape().as_list()[1]
+
+  # Make sure we are using the appropriate number of classes.
+  network = partial(network, num_classes=num_classes)
+
+  # Add the classification/pose estimation loss to the source domain.
+  source_endpoints = add_task_loss(source_images, source_labels, network,
+                                   params)
+
+  if similarity_loss == 'none':
+    # No domain adaptation, we can stop here.
+    return
+
+  with tf.variable_scope('towers', reuse=True):
+    target_logits, target_endpoints = network(
+        target_images, weight_decay=params['weight_decay'], prefix='target')
+
+  # Plot target accuracy of the train set.
+  target_accuracy = utils.accuracy(
+      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))
+
+  if 'quaternions' in target_labels:
+    target_quaternion_loss = losses.log_quaternion_loss(
+        target_labels['quaternions'], target_endpoints['quaternion_pred'],
+        params)
+    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
+
+  tf.summary.scalar('eval/Target accuracy', target_accuracy)
+
+  source_shared = source_endpoints[params['layers_to_regularize']]
+  target_shared = target_endpoints[params['layers_to_regularize']]
+
+  # When using the semisupervised model we include labeled target data in the
+  # source classifier. We do not want to include this labeled target data
+  # when we compute the similarity loss.
+  indices = tf.range(0, source_shared.get_shape().as_list()[0])
+  indices = tf.boolean_mask(indices, domain_selection_mask)
+  add_similarity_loss(similarity_loss,
+                      tf.gather(source_shared, indices),
+                      tf.gather(target_shared, indices), params)
+
+  if params['use_separation']:
+    add_autoencoders(
+        source_images,
+        source_shared,
+        target_images,
+        target_shared,
+        params=params,)
+
+
+def add_similarity_loss(method_name,
+                        source_samples,
+                        target_samples,
+                        params,
+                        scope=None):
+  """Adds a loss encouraging the shared encoding from each domain to be similar.
+
+  Args:
+    method_name: the name of the encoding similarity method to use. Valid
+      options include `dann_loss', `mmd_loss' or `correlation_loss'.
+    source_samples: a tensor of shape [num_samples, num_features].
+    target_samples: a tensor of shape [num_samples, num_features].
+    params: a dictionary of parameters. Expecting 'gamma_weight'.
+    scope: optional name scope for summary tags.
+  Raises:
+    ValueError: if `method_name` is not recognized.
+  """
+  weight = dsn_loss_coefficient(params) * params['gamma_weight']
+  method = getattr(losses, method_name)
+  method(source_samples, target_samples, weight, scope)
+
+
+def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
+  """Adds a reconstruction loss.
+
+  Args:
+    recon_loss_name: The name of the reconstruction loss.
+    images: A `Tensor` of size [batch_size, height, width, 3].
+    recons: A `Tensor` whose size matches `images`.
+    weight: A scalar coefficient for the loss.
+    domain: The name of the domain being reconstructed.
+
+  Raises:
+    ValueError: If `recon_loss_name` is not recognized.
+  """
+  if recon_loss_name == 'sum_of_pairwise_squares':
+    loss_fn = tf.contrib.losses.mean_pairwise_squared_error
+  elif recon_loss_name == 'sum_of_squares':
+    loss_fn = tf.contrib.losses.mean_squared_error
+  else:
+    raise ValueError('recon_loss_name value [%s] not recognized.' %
+                     recon_loss_name)
+
+  loss = loss_fn(recons, images, weight)
+  assert_op = tf.Assert(tf.is_finite(loss), [loss])
+  with tf.control_dependencies([assert_op]):
+    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)
+
+
+def add_autoencoders(source_data, source_shared, target_data, target_shared,
+                     params):
+  """Adds the encoders/decoders for our domain separation model w/ incoherence.
+
+  Args:
+    source_data: images from the source domain, a tensor of size
+      [batch_size, height, width, channels]
+    source_shared: a tensor with first dimension batch_size
+    target_data: images from the target domain, a tensor of size
+      [batch_size, height, width, channels]
+    target_shared: a tensor with first dimension batch_size
+    params: A dictionary of parameters. Expecting 'layers_to_regularize',
+      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
+      'encoder_name', 'weight_decay'
+  """
+
+  def normalize_images(images):
+    images -= tf.reduce_min(images)
+    return images / tf.reduce_max(images)
+
+  def concat_operation(shared_repr, private_repr):
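+    # NB: despite its name, this combines the shared and private codes by
+    # element-wise addition rather than concatenation.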
+    return shared_repr + private_repr
+
+  mu = dsn_loss_coefficient(params)
+
+  # The layer to concatenate the networks at.
+  concat_layer = params['layers_to_regularize']
+
+  # The coefficient for modulating the private/shared difference loss.
+  difference_loss_weight = params['beta_weight'] * mu
+
+  # The reconstruction weight.
+  recon_loss_weight = params['alpha_weight'] * mu
+
+  # The reconstruction loss to use.
+  recon_loss_name = params['recon_loss_name']
+
+  # The decoder/encoder to use.
+  decoder_name = params['decoder_name']
+  encoder_name = params['encoder_name']
+
+  _, height, width, _ = source_data.get_shape().as_list()
+  code_size = source_shared.get_shape().as_list()[-1]
+  weight_decay = params['weight_decay']
+
+  encoder_fn = getattr(models, encoder_name)
+  # Encode each domain with its own private encoder.
+  with tf.variable_scope('source_encoder'):
+    source_endpoints = encoder_fn(
+        source_data, code_size, weight_decay=weight_decay)
+
+  with tf.variable_scope('target_encoder'):
+    target_endpoints = encoder_fn(
+        target_data, code_size, weight_decay=weight_decay)
+
+  decoder_fn = getattr(models, decoder_name)
+
+  decoder = partial(
+      decoder_fn,
+      height=height,
+      width=width,
+      channels=source_data.get_shape().as_list()[-1],
+      weight_decay=weight_decay)
+
+  # Reconstruct each domain from its shared and private codes.
+  source_private = source_endpoints[concat_layer]
+  target_private = target_endpoints[concat_layer]
+  with tf.variable_scope('decoder'):
+    source_recons = decoder(concat_operation(source_shared, source_private))
+
+  with tf.variable_scope('decoder', reuse=True):
+    source_private_recons = decoder(
+        concat_operation(tf.zeros_like(source_private), source_private))
+    source_shared_recons = decoder(
+        concat_operation(source_shared, tf.zeros_like(source_shared)))
+
+  with tf.variable_scope('decoder', reuse=True):
+    target_recons = decoder(concat_operation(target_shared, target_private))
+    target_shared_recons = decoder(
+        concat_operation(target_shared, tf.zeros_like(target_shared)))
+    target_private_recons = decoder(
+        concat_operation(tf.zeros_like(target_private), target_private))
+
+  losses.difference_loss(
+      source_private,
+      source_shared,
+      weight=difference_loss_weight,
+      name='Source')
+  losses.difference_loss(
+      target_private,
+      target_shared,
+      weight=difference_loss_weight,
+      name='Target')
+
+  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
+                          recon_loss_weight, 'source')
+  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
+                          recon_loss_weight, 'target')
+
+  # Add summaries
+  source_reconstructions = tf.concat(
+      map(normalize_images, [
+          source_data, source_recons, source_shared_recons,
+          source_private_recons
+      ]), 2)
+  target_reconstructions = tf.concat(
+      map(normalize_images, [
+          target_data, target_recons, target_shared_recons,
+          target_private_recons
+      ]), 2)
+  tf.summary.image(
+      'Source Images:Recons:RGB',
+      source_reconstructions[:, :, :, :3],
+      max_outputs=10)
+  tf.summary.image(
+      'Target Images:Recons:RGB',
+      target_reconstructions[:, :, :, :3],
+      max_outputs=10)
+
+  if source_reconstructions.get_shape().as_list()[3] == 4:
+    tf.summary.image(
+        'Source Images:Recons:Depth',
+        source_reconstructions[:, :, :, 3:4],
+        max_outputs=10)
+    tf.summary.image(
+        'Target Images:Recons:Depth',
+        target_reconstructions[:, :, :, 3:4],
+        max_outputs=10)
+
+
+def add_task_loss(source_images, source_labels, basic_tower, params):
+  """Adds a classification and/or pose estimation loss to the model.
+
+  Args:
+    source_images: images from the source domain, a tensor of size
+      [batch_size, height, width, channels]
+    source_labels: labels from the source domain, a tensor of size
+      [batch_size], or a tuple of (quaternions, class_labels).
+    basic_tower: a function that creates the single tower of the model.
+    params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.
+  Returns:
+    The source endpoints.
+
+  Raises:
+    RuntimeError: if basic tower does not support pose estimation.
+  """
+  with tf.variable_scope('towers'):
+    source_logits, source_endpoints = basic_tower(
+        source_images, weight_decay=params['weight_decay'], prefix='Source')
+
+  if 'quaternions' in source_labels:  # We have pose estimation as well
+    if 'quaternion_pred' not in source_endpoints:
+      raise RuntimeError(
+          'Please use a model that supports pose estimation, e.g. pose_mini')
+
+    loss = losses.log_quaternion_loss(source_labels['quaternions'],
+                                      source_endpoints['quaternion_pred'],
+                                      params)
+
+    assert_op = tf.Assert(tf.is_finite(loss), [loss])
+    with tf.control_dependencies([assert_op]):
+      quaternion_loss = loss
+      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
+    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
+    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)
+
+  classification_loss = tf.losses.softmax_cross_entropy(
+      source_labels['classes'], source_logits)
+
+  tf.summary.scalar('losses/classification_loss', classification_loss)
+  return source_endpoints
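
The difference_loss invoked in add_autoencoders is defined in losses.py, which
is not part of this hunk. As described in the DSN paper, it is a soft
orthogonality constraint between the shared and private codes of each domain.
A minimal sketch under that assumption (difference_loss_sketch is a
hypothetical name, not the repository's exact implementation):

import tensorflow as tf

def difference_loss_sketch(private, shared):
  """Squared Frobenius norm of the correlation between the two codes.

  Both arguments are [batch_size, code_size]; driving this to zero pushes
  the private and shared subspaces towards orthogonality.
  """
  private = private - tf.reduce_mean(private, 0)
  shared = shared - tf.reduce_mean(shared, 0)
  correlation = tf.matmul(private, shared, transpose_a=True)
  return tf.reduce_sum(tf.square(correlation))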

+ 175 - 0
domain_adaptation/domain_separation/dsn_eval.py

@@ -0,0 +1,175 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Evaluation for Domain Separation Networks (DSNs).
+
+To build locally for CPU:
+  blaze build -c opt --copt=-mavx \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To build locally for GPU:
+  blaze build -c opt --copt=-mavx --config=cuda_clang \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
+
+To run locally:
+$ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
+    --alsologtostderr
+"""
+# pylint: enable=line-too-long
+import math
+
+import numpy as np
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+from domain_adaptation.domain_separation import losses
+from domain_adaptation.domain_separation import models
+
+slim = tf.contrib.slim
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('master', '',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
+                           'Directory where the model was written to.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/da/',
+    'Directory where we should write the tf summaries to.')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+                           'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('dataset', 'mnist_m',
+                           'Which dataset to test on: "mnist", "mnist_m".')
+
+tf.app.flags.DEFINE_string('split', 'valid',
+                           'Which portion to test on: "valid", "test".')
+
+tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
+                           'The basic tower building block.')
+
+tf.app.flags.DEFINE_bool('enable_precision_recall', False,
+                         'If True, precision and recall for each class will '
+                         'be added to the metrics.')
+
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+
+def quaternion_metric(predictions, labels):
+  params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
+  logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
+  return slim.metrics.streaming_mean(logcost)
+
+
+def angle_diff(true_q, pred_q):
+  angles = 2 * (
+      180.0 /
+      np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
+  return angles
+
+
+def provide_batch_fn():
+  """ The provide_batch function to use. """
+  return dataset_factory.provide_batch
+
+
+def main(_):
+  g = tf.Graph()
+  with g.as_default():
+    # Load the data.
+    images, labels = provide_batch_fn()(
+        FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
+
+    num_classes = labels['classes'].get_shape().as_list()[1]
+
+    tf.summary.image('eval_images', images, max_outputs=3)
+
+    # Define the model:
+    with tf.variable_scope('towers'):
+      basic_tower = getattr(models, FLAGS.basic_tower)
+      predictions, endpoints = basic_tower(
+          images,
+          num_classes=num_classes,
+          is_training=False,
+          batch_norm_params=None)
+    metric_names_to_values = {}
+
+    # Define the metrics:
+    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
+      quaternion_loss = quaternion_metric(labels['quaternions'],
+                                          endpoints['quaternion_pred'])
+
+      angle_errors, = tf.py_func(
+          angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
+          [tf.float32])
+
+      metric_names_to_values[
+          'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
+      metric_names_to_values['Quaternion Loss'] = quaternion_loss
+
+    accuracy = tf.contrib.metrics.streaming_accuracy(
+        tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
+
+    predictions = tf.argmax(predictions, 1)
+    labels = tf.argmax(labels['classes'], 1)
+    metric_names_to_values['Accuracy'] = accuracy
+
+    if FLAGS.enable_precision_recall:
+      for i in xrange(num_classes):
+        index_map = tf.one_hot(i, depth=num_classes)
+        name = 'PR/Precision_{}'.format(i)
+        metric_names_to_values[name] = slim.metrics.streaming_precision(
+            tf.gather(index_map, predictions), tf.gather(index_map, labels))
+        name = 'PR/Recall_{}'.format(i)
+        metric_names_to_values[name] = slim.metrics.streaming_recall(
+            tf.gather(index_map, predictions), tf.gather(index_map, labels))
+
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
+        metric_names_to_values)
+
+    # Create the summary ops such that they also print out to std output:
+    summary_ops = []
+    for metric_name, metric_value in names_to_values.iteritems():
+      op = tf.summary.scalar(metric_name, metric_value)
+      op = tf.Print(op, [metric_value], metric_name)
+      summary_ops.append(op)
+
+    # This ensures that we make a single pass over all of the data.
+    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
+
+    # Setup the global step.
+    slim.get_or_create_global_step()
+    slim.evaluation.evaluation_loop(
+        FLAGS.master,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=names_to_updates.values(),
+        summary_op=tf.summary.merge(summary_ops))
+
+
+if __name__ == '__main__':
+  tf.app.run()
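
The per-class precision/recall block above relies on a small indexing trick:
tf.one_hot(i, depth) builds an indicator vector for class i, and tf.gather
with the integer predictions (or labels) turns it into the binary "is class i"
tensor that the streaming precision/recall metrics expect. In isolation:

import tensorflow as tf

num_classes = 3
predictions = tf.constant([0, 2, 1, 2])       # argmax outputs
index_map = tf.one_hot(2, depth=num_classes)  # [0., 0., 1.]

# 1.0 exactly where the prediction is class 2, and 0.0 elsewhere.
is_class_2 = tf.gather(index_map, predictions)

with tf.Session() as sess:
  print(sess.run(is_class_2))  # [0. 1. 0. 1.]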

+ 157 - 0
domain_adaptation/domain_separation/dsn_test.py

@@ -0,0 +1,157 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN model assembly functions."""
+
+import numpy as np
+import tensorflow as tf
+
+import dsn
+
+
+class HelperFunctionsTest(tf.test.TestCase):
+
+  def testBasicDomainSeparationStartPoint(self):
+    with self.test_session() as sess:
+      # Test for when global_step < domain_separation_startpoint
+      step = tf.contrib.slim.get_or_create_global_step()
+      sess.run(tf.initialize_all_variables())  # global_step = 0
+      params = {'domain_separation_startpoint': 2}
+      weight = dsn.dsn_loss_coefficient(params)
+      weight_np = sess.run(weight)
+      self.assertAlmostEqual(weight_np, 1e-10)
+
+      step_op = tf.assign_add(step, 1)
+      step_np = sess.run(step_op)  # global_step = 1
+      weight = dsn.dsn_loss_coefficient(params)
+      weight_np = sess.run(weight)
+      self.assertAlmostEqual(weight_np, 1e-10)
+
+      # Test for when global_step >= domain_separation_startpoint
+      step_np = sess.run(step_op)  # global_step = 2
+      tf.logging.info(step_np)
+      weight = dsn.dsn_loss_coefficient(params)
+      weight_np = sess.run(weight)
+      self.assertAlmostEqual(weight_np, 1.0)
+
+
+class DsnModelAssemblyTest(tf.test.TestCase):
+
+  def _testBuildDefaultModel(self):
+    images = tf.to_float(np.random.rand(32, 28, 28, 1))
+    labels = {}
+    labels['classes'] = tf.one_hot(
+        tf.to_int32(np.random.randint(0, 10, (32))), 10)
+
+    params = {
+        'use_separation': True,
+        'layers_to_regularize': 'fc3',
+        'weight_decay': 0.0,
+        'ps_tasks': 1,
+        'domain_separation_startpoint': 1,
+        'alpha_weight': 1,
+        'beta_weight': 1,
+        'gamma_weight': 1,
+        'recon_loss_name': 'sum_of_squares',
+        'decoder_name': 'small_decoder',
+        'encoder_name': 'default_encoder',
+    }
+    return images, labels, params
+
+  def testBuildModelDann(self):
+    images, labels, params = self._testBuildDefaultModel()
+
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'dann_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
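+    # With separation enabled, six losses are expected: the classification
+    # loss, the similarity loss, two difference losses and two reconstruction
+    # losses (one each for source and target).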
+    self.assertEqual(len(loss_tensors), 6)
+
+  def testBuildModelDannSumOfPairwiseSquares(self):
+    images, labels, params = self._testBuildDefaultModel()
+    params['recon_loss_name'] = 'sum_of_pairwise_squares'
+
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'dann_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+    self.assertEqual(len(loss_tensors), 6)
+
+  def testBuildModelDannMultiPSTasks(self):
+    images, labels, params = self._testBuildDefaultModel()
+    params['ps_tasks'] = 10
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'dann_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+    self.assertEqual(len(loss_tensors), 6)
+
+  def testBuildModelMmd(self):
+    images, labels, params = self._testBuildDefaultModel()
+
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'mmd_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+    self.assertEqual(len(loss_tensors), 6)
+
+  def testBuildModelCorr(self):
+    images, labels, params = self._testBuildDefaultModel()
+
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'correlation_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+    self.assertEqual(len(loss_tensors), 6)
+
+  def testBuildModelNoDomainAdaptation(self):
+    images, labels, params = self._testBuildDefaultModel()
+    params['use_separation'] = False
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
+                       params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+      self.assertEqual(len(loss_tensors), 1)
+      self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)
+
+  def testBuildModelNoAdaptationWeightDecay(self):
+    images, labels, params = self._testBuildDefaultModel()
+    params['use_separation'] = False
+    params['weight_decay'] = 1e-5
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
+                       params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+      self.assertEqual(len(loss_tensors), 1)
+      self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)
+
+  def testBuildModelNoSeparation(self):
+    images, labels, params = self._testBuildDefaultModel()
+    params['use_separation'] = False
+    with self.test_session():
+      dsn.create_model(images, labels,
+                       tf.cast(tf.ones([32,]), tf.bool), images, labels,
+                       'dann_loss', params, 'dann_mnist')
+      loss_tensors = tf.contrib.losses.get_losses()
+    self.assertEqual(len(loss_tensors), 2)
+
+
+if __name__ == '__main__':
+  tf.test.main()

+ 301 - 0
domain_adaptation/domain_separation/dsn_train.py

@@ -0,0 +1,301 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=line-too-long
+r"""Training for Domain Separation Networks (DSNs).
+
+-- Compile:
+$ blaze build -c opt --copt=-mavx --config=cuda \
+    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train
+
+-- Run:
+$ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train \
+    --similarity_loss=dann \
+    --basic_tower=dsn_cropped_linemod \
+    --source_dataset=pose_synthetic \
+    --target_dataset=pose_real \
+    --learning_rate=0.012 \
+    --alpha_weight=0.26 \
+    --gamma_weight=0.0115 \
+    --weight_decay=4e-5 \
+    --layers_to_regularize=fc3 \
+    --use_separation \
+    --alsologtostderr
+"""
+# pylint: enable=line-too-long
+from __future__ import division
+
+import tensorflow as tf
+
+from domain_adaptation.datasets import dataset_factory
+import dsn
+
+slim = tf.contrib.slim
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_integer('batch_size', 32,
+                            'The number of images in each batch.')
+
+tf.app.flags.DEFINE_string('source_dataset', 'pose_synthetic',
+                           'Source dataset to train on.')
+
+tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
+                           'Target dataset to train on.')
+
+tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
+                           'A labeled target dataset to mix into the source '
+                           'data for semi-supervised training, or "none".')
+
+tf.app.flags.DEFINE_string('dataset_dir', None,
+                           'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_string('master', '',
+                           'BNS name of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string('train_log_dir', '/tmp/da/',
+                           'Directory where to write event logs.')
+
+tf.app.flags.DEFINE_string(
+    'layers_to_regularize', 'fc3',
+    'Comma-separated list of layer names to use MMD regularization on.')
+
+tf.app.flags.DEFINE_float('learning_rate', .01, 'The learning rate')
+
+tf.app.flags.DEFINE_float('alpha_weight', 1e-6,
+                          'The coefficient for scaling the reconstruction '
+                          'loss.')
+
+tf.app.flags.DEFINE_float(
+    'beta_weight', 1e-6,
+    'The coefficient for scaling the private/shared difference loss.')
+
+tf.app.flags.DEFINE_float(
+    'gamma_weight', 1e-6,
+    'The coefficient for scaling the shared encoding similarity loss.')
+
+tf.app.flags.DEFINE_float('pose_weight', 0.125,
+                          'The coefficient for scaling the pose loss.')
+
+tf.app.flags.DEFINE_float(
+    'weight_decay', 1e-6,
+    'The coefficient for the L2 regularization applied for all weights.')
+
+tf.app.flags.DEFINE_integer(
+    'save_summaries_secs', 60,
+    'The frequency with which summaries are saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+    'save_interval_secs', 60,
+    'The frequency with which the model is saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+    'max_number_of_steps', None,
+    'The maximum number of gradient steps. Use None to train indefinitely.')
+
+tf.app.flags.DEFINE_integer(
+    'domain_separation_startpoint', 1,
+    'The global step to add the domain separation losses.')
+
+tf.app.flags.DEFINE_integer(
+    'bipartite_assignment_top_k', 3,
+    'The number of top-k matches to use in bipartite matching adaptation.')
+
+tf.app.flags.DEFINE_float('decay_rate', 0.95, 'Learning rate decay factor.')
+
+tf.app.flags.DEFINE_integer('decay_steps', 20000, 'Learning rate decay steps.')
+
+tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
+
+tf.app.flags.DEFINE_bool('use_separation', False,
+                         'Use our domain separation model.')
+
+tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
+
+tf.app.flags.DEFINE_integer(
+    'ps_tasks', 0,
+    'The number of parameter servers. If the value is 0, then the parameters '
+    'are handled locally by the worker.')
+
+tf.app.flags.DEFINE_integer(
+    'num_readers', 4,
+    'The number of parallel readers that read data from the dataset.')
+
+tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
+                            'The number of threads used to create the batches.')
+
+tf.app.flags.DEFINE_integer(
+    'task', 0,
+    'The Task ID. This value is used when training with multiple workers to '
+    'identify each worker.')
+
+tf.app.flags.DEFINE_string('decoder_name', 'small_decoder',
+                           'The decoder to use.')
+tf.app.flags.DEFINE_string('encoder_name', 'default_encoder',
+                           'The encoder to use.')
+
+################################################################################
+# Flags that control the architecture and losses
+################################################################################
+tf.app.flags.DEFINE_string(
+    'similarity_loss', 'grl',
+    'The method to use for encouraging the common encoder codes to be '
+    'similar, one of "grl", "mmd", "corr".')
+
+tf.app.flags.DEFINE_string('recon_loss_name', 'sum_of_pairwise_squares',
+                           'The name of the reconstruction loss.')
+
+tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
+                           'The basic tower building block.')
+
+
+def provide_batch_fn():
+  """The provide_batch function to use."""
+  return dataset_factory.provide_batch
+
+
+def main(_):
+  model_params = {
+      'use_separation': FLAGS.use_separation,
+      'domain_separation_startpoint': FLAGS.domain_separation_startpoint,
+      'layers_to_regularize': FLAGS.layers_to_regularize,
+      'alpha_weight': FLAGS.alpha_weight,
+      'beta_weight': FLAGS.beta_weight,
+      'gamma_weight': FLAGS.gamma_weight,
+      'pose_weight': FLAGS.pose_weight,
+      'recon_loss_name': FLAGS.recon_loss_name,
+      'decoder_name': FLAGS.decoder_name,
+      'encoder_name': FLAGS.encoder_name,
+      'weight_decay': FLAGS.weight_decay,
+      'batch_size': FLAGS.batch_size,
+      'use_logging': FLAGS.use_logging,
+      'ps_tasks': FLAGS.ps_tasks,
+      'task': FLAGS.task,
+  }
+  g = tf.Graph()
+  with g.as_default():
+    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
+      # Load the data.
+      source_images, source_labels = provide_batch_fn()(
+          FLAGS.source_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
+      target_images, target_labels = provide_batch_fn()(
+          FLAGS.target_dataset, 'train', FLAGS.dataset_dir, FLAGS.num_readers,
+          FLAGS.batch_size, FLAGS.num_preprocessing_threads)
+
+      # In the unsupervised case all the samples in the labeled
+      # domain are from the source domain.
+      domain_selection_mask = tf.fill((source_images.get_shape().as_list()[0],),
+                                      True)
+
+      # When using the semisupervised model we include labeled target data in
+      # the source labelled data.
+      if FLAGS.target_labeled_dataset != 'none':
+        # 1000 is the maximum number of labelled target samples that exist in
+        # the datasets.
+        target_semi_images, target_semi_labels = provide_batch_fn()(
+            FLAGS.target_labeled_dataset, 'train', FLAGS.dataset_dir,
+            FLAGS.num_readers, FLAGS.batch_size,
+            FLAGS.num_preprocessing_threads)
+
+        # Calculate the proportion of source domain samples in the semi-
+        # supervised setting, so that the proportion is set accordingly in the
+        # batches.
+        proportion = float(source_labels['num_train_samples']) / (
+            source_labels['num_train_samples'] +
+            target_semi_labels['num_train_samples'])
+
+        rnd_tensor = tf.random_uniform(
+            (target_semi_images.get_shape().as_list()[0],))
+
+        domain_selection_mask = rnd_tensor < proportion
+        source_images = tf.where(domain_selection_mask, source_images,
+                                 target_semi_images)
+        source_class_labels = tf.where(domain_selection_mask,
+                                       source_labels['classes'],
+                                       target_semi_labels['classes'])
+
+        if 'quaternions' in source_labels:
+          source_pose_labels = tf.where(domain_selection_mask,
+                                        source_labels['quaternions'],
+                                        target_semi_labels['quaternions'])
+          (source_images, source_class_labels, source_pose_labels,
+           domain_selection_mask) = tf.train.shuffle_batch(
+               [
+                   source_images, source_class_labels, source_pose_labels,
+                   domain_selection_mask
+               ],
+               FLAGS.batch_size,
+               50000,
+               5000,
+               num_threads=1,
+               enqueue_many=True)
+
+        else:
+          (source_images, source_class_labels,
+           domain_selection_mask) = tf.train.shuffle_batch(
+               [source_images, source_class_labels, domain_selection_mask],
+               FLAGS.batch_size,
+               50000,
+               5000,
+               num_threads=1,
+               enqueue_many=True)
+        had_quaternions = 'quaternions' in source_labels
+        source_labels = {}
+        source_labels['classes'] = source_class_labels
+        if had_quaternions:
+          source_labels['quaternions'] = source_pose_labels
+
+      slim.get_or_create_global_step()
+      tf.summary.image('source_images', source_images, max_outputs=3)
+      tf.summary.image('target_images', target_images, max_outputs=3)
+
+      dsn.create_model(
+          source_images,
+          source_labels,
+          domain_selection_mask,
+          target_images,
+          target_labels,
+          FLAGS.similarity_loss,
+          model_params,
+          basic_tower_name=FLAGS.basic_tower)
+
+      # Configure the optimization scheme:
+      learning_rate = tf.train.exponential_decay(
+          FLAGS.learning_rate,
+          slim.get_or_create_global_step(),
+          FLAGS.decay_steps,
+          FLAGS.decay_rate,
+          staircase=True,
+          name='learning_rate')
+
+      tf.summary.scalar('learning_rate', learning_rate)
+      tf.summary.scalar('total_loss', tf.losses.get_total_loss())
+
+      opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
+      tf.logging.set_verbosity(tf.logging.INFO)
+      # Run training.
+      loss_tensor = slim.learning.create_train_op(
+          slim.losses.get_total_loss(),
+          opt,
+          summarize_gradients=True,
+          colocate_gradients_with_ops=True)
+      slim.learning.train(
+          train_op=loss_tensor,
+          logdir=FLAGS.train_log_dir,
+          master=FLAGS.master,
+          is_chief=FLAGS.task == 0,
+          number_of_steps=FLAGS.max_number_of_steps,
+          save_summaries_secs=FLAGS.save_summaries_secs,
+          save_interval_secs=FLAGS.save_interval_secs)
+
+
+if __name__ == '__main__':
+  tf.app.run()
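
The learning-rate schedule configured above is a staircase exponential decay:
the rate is multiplied by decay_rate once every decay_steps global steps. A
small sketch of how the default flag values (0.01, 0.95, 20000) behave:

import tensorflow as tf

step = tf.placeholder(tf.int64, shape=[])
lr = tf.train.exponential_decay(
    0.01, step, decay_steps=20000, decay_rate=0.95, staircase=True)

with tf.Session() as sess:
  for s in [0, 19999, 20000, 40000]:
    # staircase=True floors step/decay_steps, so the rate drops in jumps:
    # 0.01, 0.01, 0.0095, 0.009025.
    print(s, sess.run(lr, feed_dict={step: s}))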

+ 34 - 0
domain_adaptation/domain_separation/grl_op_grads.py

@@ -0,0 +1,34 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Gradients for operators defined in grl_ops.py."""
+import tensorflow as tf
+
+
+@tf.RegisterGradient("GradientReversal")
+def _GradientReversalGrad(_, grad):
+  """The gradients for `gradient_reversal`.
+
+  Args:
+    _: The `gradient_reversal` `Operation` that we are differentiating,
+      which we can use to find the inputs and outputs of the original op.
+    grad: Gradient with respect to the output of the `gradient_reversal` op.
+
+  Returns:
+    Gradient with respect to the input of `gradient_reversal`, which is simply
+    the negative of the input gradient.
+
+  """
+  return tf.negative(grad)
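
The registration above gives the custom kernel a backward pass that simply
negates the incoming gradient. If the compiled op is unavailable, the same
forward-identity / backward-negation behavior can be approximated in pure
TensorFlow with a graph-level gradient override (a sketch, not part of this
package; the names are hypothetical):

import tensorflow as tf

@tf.RegisterGradient('ReverseGradSketch')
def _reverse_grad_sketch(unused_op, grad):
  return tf.negative(grad)

def gradient_reversal_sketch(x):
  """Identity in the forward pass, negated gradient in the backward pass."""
  g = tf.get_default_graph()
  with g.gradient_override_map({'Identity': 'ReverseGradSketch'}):
    return tf.identity(x)

x = tf.constant([1.0, 2.0])
y = gradient_reversal_sketch(x)
dx, = tf.gradients(tf.reduce_sum(y), [x])

with tf.Session() as sess:
  print(sess.run([y, dx]))  # y == [1., 2.], dx == [-1., -1.]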

+ 47 - 0
domain_adaptation/domain_separation/grl_op_kernels.cc

@@ -0,0 +1,47 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains the implementations of the ops registered in
+// grl_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace tensorflow {
+
+// The gradient reversal op is used in domain adversarial training.  It behaves
+// as the identity op during forward propagation, and multiplies the incoming
+// gradient by -1 during backward propagation.
+class GradientReversalOp : public OpKernel {
+ public:
+  explicit GradientReversalOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  // Gradient reversal op behaves as the identity op during forward
+  // propagation. Compute() function copied from the IdentityOp::Compute()
+  // function here: third_party/tensorflow/core/kernels/identity_op.h.
+  void Compute(OpKernelContext* context) override {
+    if (IsRefType(context->input_dtype(0))) {
+      context->forward_ref_input_to_ref_output(0, 0);
+    } else {
+      context->set_output(0, context->input(0));
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GradientReversal").Device(DEVICE_CPU),
+                        GradientReversalOp);
+
+}  // namespace tensorflow

+ 16 - 0
domain_adaptation/domain_separation/grl_op_shapes.py

@@ -0,0 +1,16 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Shape inference for operators defined in grl_ops.cc."""

+ 36 - 0
domain_adaptation/domain_separation/grl_ops.cc

@@ -0,0 +1,36 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains custom ops.
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// This custom op is used by adversarial training.
+REGISTER_OP("GradientReversal")
+    .Input("input: float")
+    .Output("output: float")
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+This op copies the input to the output during forward propagation, and
+negates the gradient during backward propagation.
+
+input: Tensor.
+output: Tensor, copied from input.
+)doc");
+
+}  // namespace tensorflow

+ 28 - 0
domain_adaptation/domain_separation/grl_ops.py

@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GradientReversal op Python library."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+import tensorflow as tf
+
+tf.logging.info(tf.resource_loader.get_data_files_path())
+_grl_ops_module = tf.load_op_library(
+    os.path.join(tf.resource_loader.get_data_files_path(),
+                 '_grl_ops.so'))
+gradient_reversal = _grl_ops_module.gradient_reversal

+ 73 - 0
domain_adaptation/domain_separation/grl_ops_test.py

@@ -0,0 +1,73 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for grl_ops."""
+
+import tensorflow as tf
+
+# Importing grl_op_grads registers the GradientReversal gradient.
+import grl_op_grads  # pylint: disable=unused-import
+import grl_ops
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class GRLOpsTest(tf.test.TestCase):
+
+  def testGradientReversalOp(self):
+    with tf.Graph().as_default():
+      with self.test_session():
+        # Test that in forward prop, gradient reversal op acts as the
+        # identity operation.
+        examples = tf.constant([5.0, 4.0, 3.0, 2.0, 1.0])
+        output = grl_ops.gradient_reversal(examples)
+        expected_output = examples
+        self.assertAllEqual(output.eval(), expected_output.eval())
+
+        # Test that shape inference works as expected.
+        self.assertAllEqual(output.get_shape(), expected_output.get_shape())
+
+        # Test that in backward prop, gradient reversal op multiplies
+        # gradients by -1.
+        examples = tf.constant([[1.0]])
+        w = tf.get_variable(name='w', shape=[1, 1])
+        b = tf.get_variable(name='b', shape=[1])
+        init_op = tf.global_variables_initializer()
+        init_op.run()
+        features = tf.nn.xw_plus_b(examples, w, b)
+        # Construct two outputs: features layer passes directly to output1, but
+        # features layer passes through a gradient reversal layer before
+        # reaching output2.
+        output1 = features
+        output2 = grl_ops.gradient_reversal(features)
+        gold = tf.constant([1.0])
+        loss1 = gold - output1
+        loss2 = gold - output2
+        opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+        grads_and_vars_1 = opt.compute_gradients(loss1,
+                                                 tf.trainable_variables())
+        grads_and_vars_2 = opt.compute_gradients(loss2,
+                                                 tf.trainable_variables())
+        self.assertAllEqual(len(grads_and_vars_1), len(grads_and_vars_2))
+        for i in range(len(grads_and_vars_1)):
+          g1 = grads_and_vars_1[i][0]
+          g2 = grads_and_vars_2[i][0]
+          # Verify that gradients of loss1 are the negative of gradients of
+          # loss2.
+          self.assertAllEqual(tf.negative(g1).eval(), g2.eval())
+
+if __name__ == '__main__':
+  tf.test.main()

+ 292 - 0
domain_adaptation/domain_separation/losses.py

@@ -0,0 +1,292 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Domain Adaptation Loss Functions.
+
+The following domain adaptation loss functions are defined:
+
+- Maximum Mean Discrepancy (MMD).
+  Relevant paper:
+    Gretton, Arthur, et al.,
+    "A kernel two-sample test."
+    The Journal of Machine Learning Research, 2012
+
+- Correlation Loss on a batch.
+"""
+from functools import partial
+import tensorflow as tf
+
+import grl_op_grads  # pylint: disable=unused-import
+import grl_op_shapes  # pylint: disable=unused-import
+import grl_ops
+import utils
+slim = tf.contrib.slim
+
+
+################################################################################
+# SIMILARITY LOSS
+################################################################################
+def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
+  r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.
+
+  Maximum Mean Discrepancy (MMD) is a distance measure between the samples of
+  the distributions of x and y. Here we use the kernel two-sample estimate
+  using the empirical mean of the two distributions.
+
+  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
+              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },
+
+  where K(x, y) = <\phi(x), \phi(y)> is the desired kernel function, in this
+  case a radial basis kernel.
+
+  Args:
+      x: a tensor of shape [num_samples, num_features]
+      y: a tensor of shape [num_samples, num_features]
+      kernel: a function which computes the kernel in MMD. Defaults to
+              utils.gaussian_kernel_matrix.
+
+  Returns:
+      a scalar denoting the squared maximum mean discrepancy loss.
+  """
+  with tf.name_scope('MaximumMeanDiscrepancy'):
+    # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
+    cost = tf.reduce_mean(kernel(x, x))
+    cost += tf.reduce_mean(kernel(y, y))
+    cost -= 2 * tf.reduce_mean(kernel(x, y))
+
+    # We do not allow the loss to become negative.
+    cost = tf.where(cost > 0, cost, 0, name='value')
+  return cost
+
+
+def mmd_loss(source_samples, target_samples, weight, scope=None):
+  """Adds a similarity loss term, the MMD between two representations.
+
+  This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
+  different Gaussian kernels.
+
+  Args:
+    source_samples: a tensor of shape [num_samples, num_features].
+    target_samples: a tensor of shape [num_samples, num_features].
+    weight: the weight of the MMD loss.
+    scope: optional name scope for summary tags.
+
+  Returns:
+    a scalar tensor representing the MMD loss value.
+  """
+  sigmas = [
+      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
+      1e3, 1e4, 1e5, 1e6
+  ]
+  gaussian_kernel = partial(
+      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))
+
+  loss_value = maximum_mean_discrepancy(
+      source_samples, target_samples, kernel=gaussian_kernel)
+  loss_value = tf.maximum(1e-4, loss_value) * weight
+  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
+  with tf.control_dependencies([assert_op]):
+    tag = 'MMD Loss'
+    if scope:
+      tag = scope + tag
+    tf.contrib.deprecated.scalar_summary(tag, loss_value)
+    tf.losses.add_loss(loss_value)
+
+  return loss_value
+
+
+def correlation_loss(source_samples, target_samples, weight, scope=None):
+  """Adds a similarity loss term, the correlation between two representations.
+
+  Args:
+    source_samples: a tensor of shape [num_samples, num_features]
+    target_samples: a tensor of shape [num_samples, num_features]
+    weight: a scalar weight for the loss.
+    scope: optional name scope for summary tags.
+
+  Returns:
+    a scalar tensor representing the correlation loss value.
+  """
+  with tf.name_scope('corr_loss'):
+    source_samples -= tf.reduce_mean(source_samples, 0)
+    target_samples -= tf.reduce_mean(target_samples, 0)
+
+    source_samples = tf.nn.l2_normalize(source_samples, 1)
+    target_samples = tf.nn.l2_normalize(target_samples, 1)
+
+    source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
+    target_cov = tf.matmul(tf.transpose(target_samples), target_samples)
+
+    corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight
+
+  assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
+  with tf.control_dependencies([assert_op]):
+    tag = 'Correlation Loss'
+    if scope:
+      tag = scope + tag
+    tf.contrib.deprecated.scalar_summary(tag, corr_loss)
+    tf.losses.add_loss(corr_loss)
+
+  return corr_loss
+
+
+def dann_loss(source_samples, target_samples, weight, scope=None):
+  """Adds the domain adversarial (DANN) loss.
+
+  Args:
+    source_samples: a tensor of shape [num_samples, num_features].
+    target_samples: a tensor of shape [num_samples, num_features].
+    weight: the weight of the loss.
+    scope: optional name scope for summary tags.
+
+  Returns:
+    a scalar tensor representing the domain adversarial loss value.
+  """
+  with tf.variable_scope('dann'):
+    batch_size = tf.shape(source_samples)[0]
+    samples = tf.concat([source_samples, target_samples], 0)
+    samples = slim.flatten(samples)
+
+    domain_selection_mask = tf.concat(
+        [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], 0)
+
+    # Perform the gradient reversal and be careful with the shape.
+    grl = grl_ops.gradient_reversal(samples)
+    grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))
+
+    grl = slim.fully_connected(grl, 100, scope='fc1')
+    logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')
+
+    domain_predictions = tf.sigmoid(logits)
+
+  domain_loss = tf.losses.log_loss(
+      domain_selection_mask, domain_predictions, weights=weight)
+
+  domain_accuracy = utils.accuracy(
+      tf.round(domain_predictions), domain_selection_mask)
+
+  assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
+  with tf.control_dependencies([assert_op]):
+    tag_loss = 'losses/Domain Loss'
+    tag_accuracy = 'losses/Domain Accuracy'
+    if scope:
+      tag_loss = scope + tag_loss
+      tag_accuracy = scope + tag_accuracy
+
+    tf.contrib.deprecated.scalar_summary(
+        tag_loss, domain_loss, name='domain_loss_summary')
+    tf.contrib.deprecated.scalar_summary(
+        tag_accuracy, domain_accuracy, name='domain_accuracy_summary')
+
+  return domain_loss
+
+
+################################################################################
+# DIFFERENCE LOSS
+################################################################################
+def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
+  """Adds the difference loss between the private and shared representations.
+
+  Args:
+    private_samples: a tensor of shape [num_samples, num_features].
+    shared_samples: a tensor of shape [num_samples, num_features].
+    weight: the weight of the incoherence loss.
+    name: the name of the tf summary.
+  """
+  private_samples -= tf.reduce_mean(private_samples, 0)
+  shared_samples -= tf.reduce_mean(shared_samples, 0)
+
+  private_samples = tf.nn.l2_normalize(private_samples, 1)
+  shared_samples = tf.nn.l2_normalize(shared_samples, 1)
+
+  correlation_matrix = tf.matmul(
+      private_samples, shared_samples, transpose_a=True)
+
+  cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
+  cost = tf.where(cost > 0, cost, 0, name='value')
+
+  tf.contrib.deprecated.scalar_summary('losses/Difference Loss {}'.format(name),
+                                       cost)
+  assert_op = tf.Assert(tf.is_finite(cost), [cost])
+  with tf.control_dependencies([assert_op]):
+    tf.losses.add_loss(cost)
+
+
+################################################################################
+# TASK LOSS
+################################################################################
+def log_quaternion_loss_batch(predictions, labels, params):
+  """A helper function to compute the error between quaternions.
+
+  Args:
+    predictions: A Tensor of size [batch_size, 4].
+    labels: A Tensor of size [batch_size, 4].
+    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+  Returns:
+    A Tensor of size [batch_size], denoting the error between the quaternions.
+  """
+  use_logging = params['use_logging']
+  assertions = []
+  if use_logging:
+    assertions.append(
+        tf.Assert(
+            tf.reduce_all(
+                tf.less(
+                    tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
+                    1e-4)),
+            ['The l2 norm of each prediction quaternion vector should be 1.']))
+    assertions.append(
+        tf.Assert(
+            tf.reduce_all(
+                tf.less(
+                    tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
+            ['The l2 norm of each label quaternion vector should be 1.']))
+
+  with tf.control_dependencies(assertions):
+    product = tf.multiply(predictions, labels)
+  internal_dot_products = tf.reduce_sum(product, [1])
+
+  if use_logging:
+    internal_dot_products = tf.Print(
+        internal_dot_products,
+        [internal_dot_products, tf.shape(internal_dot_products)],
+        'internal_dot_products:')
+
+  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
+  return logcost
+
+
+def log_quaternion_loss(predictions, labels, params):
+  """A helper function to compute the mean error between batches of quaternions.
+
+  The caller is expected to add the loss to the graph.
+
+  Args:
+    predictions: A Tensor of size [batch_size, 4].
+    labels: A Tensor of size [batch_size, 4].
+    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
+
+  Returns:
+    A Tensor of size 1, denoting the mean error between batches of quaternions.
+  """
+  use_logging = params['use_logging']
+  logcost = log_quaternion_loss_batch(predictions, labels, params)
+  logcost = tf.reduce_sum(logcost, [0])
+  batch_size = params['batch_size']
+  logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
+  if use_logging:
+    logcost = tf.Print(
+        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
+  return logcost

+ 110 - 0
domain_adaptation/domain_separation/losses_test.py

@@ -0,0 +1,110 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN losses."""
+from functools import partial
+
+import numpy as np
+import tensorflow as tf
+
+import losses
+import utils
+
+
+def MaximumMeanDiscrepancySlow(x, y, sigmas):
+  num_samples = x.get_shape().as_list()[0]
+
+  def AverageGaussianKernel(x, y, sigmas):
+    result = 0
+    for sigma in sigmas:
+      dist = tf.reduce_sum(tf.square(x - y))
+      result += tf.exp((-1.0 / (2.0 * sigma)) * dist)
+    return result / num_samples**2
+
+  total = 0
+
+  for i in range(num_samples):
+    for j in range(num_samples):
+      total += AverageGaussianKernel(x[i, :], x[j, :], sigmas)
+      total += AverageGaussianKernel(y[i, :], y[j, :], sigmas)
+      total += -2 * AverageGaussianKernel(x[i, :], y[j, :], sigmas)
+
+  return total
+
+
+class LogQuaternionLossTest(tf.test.TestCase):
+
+  def test_log_quaternion_loss_batch(self):
+    with self.test_session():
+      predictions = tf.random_uniform((10, 4), seed=1)
+      predictions = tf.nn.l2_normalize(predictions, 1)
+      labels = tf.random_uniform((10, 4), seed=1)
+      labels = tf.nn.l2_normalize(labels, 1)
+      params = {'batch_size': 10, 'use_logging': False}
+      x = losses.log_quaternion_loss_batch(predictions, labels, params)
+      self.assertTrue(((10,) == tf.shape(x).eval()).all())
+
+
+class MaximumMeanDiscrepancyTest(tf.test.TestCase):
+
+  def test_mmd_name(self):
+    with self.test_session():
+      x = tf.random_uniform((2, 3), seed=1)
+      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+      loss = losses.maximum_mean_discrepancy(x, x, kernel)
+
+      self.assertEqual(loss.op.name, 'MaximumMeanDiscrepancy/value')
+
+  def test_mmd_is_zero_when_inputs_are_same(self):
+    with self.test_session():
+      x = tf.random_uniform((2, 3), seed=1)
+      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+      self.assertEqual(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
+
+  def test_fast_mmd_is_similar_to_slow_mmd(self):
+    with self.test_session():
+      x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
+      y = tf.constant(np.random.rand(2, 3), tf.float32)
+
+      cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
+      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
+      cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()
+
+      self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
+
+  def test_multiple_sigmas(self):
+    with self.test_session():
+      x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
+      y = tf.constant(np.random.rand(2, 3), tf.float32)
+
+      sigmas = tf.constant([2., 5., 10, 20, 30])
+      kernel = partial(utils.gaussian_kernel_matrix, sigmas=sigmas)
+      cost_old = MaximumMeanDiscrepancySlow(x, y, [2., 5., 10, 20, 30]).eval()
+      cost_new = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
+
+      self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
+
+  def test_mmd_is_zero_when_distributions_are_same(self):
+
+    with self.test_session():
+      x = tf.random_uniform((1000, 10), seed=1)
+      y = tf.random_uniform((1000, 10), seed=3)
+
+      kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
+      loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
+
+      self.assertAlmostEqual(0, loss, delta=1e-4)
+
+if __name__ == '__main__':
+  tf.test.main()

+ 443 - 0
domain_adaptation/domain_separation/models.py

@@ -0,0 +1,443 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains different architectures for the different DSN parts.
+
+We define here the modules that can be used in the different parts of the DSN
+model.
+- shared encoder (dsn_cropped_linemod, dann_mnist, dann_svhn, dann_gtsrb)
+- private encoder (default_encoder)
+- decoder (large_decoder, gtsrb_decoder, small_decoder)
+"""
+import tensorflow as tf
+
+import utils
+
+slim = tf.contrib.slim
+
+
+def default_batch_norm_params(is_training=False):
+  """Returns default batch normalization parameters for DSNs.
+
+  Args:
+    is_training: whether or not the model is training.
+
+  Returns:
+    a dictionary that maps batch norm parameter names (strings) to values.
+  """
+  return {
+      # Decay for the moving averages.
+      'decay': 0.5,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      'is_training': is_training
+  }
+
+
+################################################################################
+# PRIVATE ENCODERS
+################################################################################
+def default_encoder(images, code_size, batch_norm_params=None,
+                    weight_decay=0.0):
+  """Encodes the given images to codes of the given size.
+
+  Args:
+    images: a tensor of size [batch_size, height, width, 1].
+    code_size: the number of hidden units in the code layer of the classifier.
+    batch_norm_params: a dictionary that maps batch norm parameter names to
+      values.
+    weight_decay: the value for the weight decay coefficient.
+
+  Returns:
+    end_points: a dictionary mapping end point names to tensors; the code for
+      the input is in end_points['fc3'].
+  """
+  end_points = {}
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params):
+    with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], padding='SAME'):
+      net = slim.conv2d(images, 32, scope='conv1')
+      net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
+      net = slim.conv2d(net, 64, scope='conv2')
+      net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
+
+      net = slim.flatten(net)
+      end_points['flatten'] = net
+      net = slim.fully_connected(net, code_size, scope='fc1')
+      end_points['fc3'] = net
+  return end_points
+
+
+################################################################################
+# DECODERS
+################################################################################
+def large_decoder(codes,
+                  height,
+                  width,
+                  channels,
+                  batch_norm_params=None,
+                  weight_decay=0.0):
+  """Decodes the codes to a fixed output size.
+
+  Args:
+    codes: a tensor of size [batch_size, code_size].
+    height: the height of the output images.
+    width: the width of the output images.
+    channels: the number of the output channels.
+    batch_norm_params: a dictionary that maps batch norm parameter names to
+      values.
+    weight_decay: the value for the weight decay coefficient.
+
+  Returns:
+    recons: the reconstruction tensor of shape
+      [batch_size, height, width, channels].
+  """
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params):
+    net = slim.fully_connected(codes, 600, scope='fc1')
+    batch_size = net.get_shape().as_list()[0]
+    net = tf.reshape(net, [batch_size, 10, 10, 6])
+
+    net = slim.conv2d(net, 32, [5, 5], scope='conv1_1')
+
+    net = tf.image.resize_nearest_neighbor(net, (16, 16))
+
+    net = slim.conv2d(net, 32, [5, 5], scope='conv2_1')
+
+    net = tf.image.resize_nearest_neighbor(net, (32, 32))
+
+    net = slim.conv2d(net, 32, [5, 5], scope='conv3_2')
+
+    output_size = [height, width]
+    net = tf.image.resize_nearest_neighbor(net, output_size)
+
+    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+      net = slim.conv2d(net, channels, activation_fn=None, scope='conv4_1')
+
+  return net
+
+
+def gtsrb_decoder(codes,
+                  height,
+                  width,
+                  channels,
+                  batch_norm_params=None,
+                  weight_decay=0.0):
+  """Decodes the codes to a fixed output size. This decoder is specific to GTSRB
+
+  Args:
+    codes: a tensor of size [batch_size, 100].
+    height: the height of the output images.
+    width: the width of the output images.
+    channels: the number of the output channels.
+    batch_norm_params: a dictionary that maps batch norm parameter names to
+      values.
+    weight_decay: the value for the weight decay coefficient.
+
+  Returns:
+    recons: the reconstruction tensor of shape
+      [batch_size, height, width, channels].
+
+  Raises:
+    ValueError: When the input code size is not 100.
+  """
+  batch_size, code_size = codes.get_shape().as_list()
+  if code_size != 100:
+    raise ValueError('The code size used as an input to the GTSRB decoder is '
+                     'expected to be 100.')
+
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params):
+    net = codes
+    net = tf.reshape(net, [batch_size, 10, 10, 1])
+    net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')
+
+    # First upsampling 20x20
+    net = tf.image.resize_nearest_neighbor(net, [20, 20])
+
+    net = slim.conv2d(net, 32, [3, 3], scope='conv2_1')
+
+    output_size = [height, width]
+    # Final upsampling 40 x 40
+    net = tf.image.resize_nearest_neighbor(net, output_size)
+
+    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+      net = slim.conv2d(net, 16, scope='conv3_1')
+      net = slim.conv2d(net, channels, activation_fn=None, scope='conv3_2')
+
+  return net
+
+
+def small_decoder(codes,
+                  height,
+                  width,
+                  channels,
+                  batch_norm_params=None,
+                  weight_decay=0.0):
+  """Decodes the codes to a fixed output size.
+
+  Args:
+    codes: a tensor of size [batch_size, code_size].
+    height: the height of the output images.
+    width: the width of the output images.
+    channels: the number of the output channels.
+    batch_norm_params: a dictionary that maps batch norm parameter names to
+      values.
+    weight_decay: the value for the weight decay coefficient.
+
+  Returns:
+    recons: the reconstruction tensor of shape
+      [batch_size, height, width, channels].
+  """
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params):
+    net = slim.fully_connected(codes, 300, scope='fc1')
+    batch_size = net.get_shape().as_list()[0]
+    net = tf.reshape(net, [batch_size, 10, 10, 3])
+
+    net = slim.conv2d(net, 16, [3, 3], scope='conv1_1')
+    net = slim.conv2d(net, 16, [3, 3], scope='conv1_2')
+
+    output_size = [height, width]
+    net = tf.image.resize_nearest_neighbor(net, output_size)
+
+    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
+      net = slim.conv2d(net, 16, scope='conv2_1')
+      net = slim.conv2d(net, channels, activation_fn=None, scope='conv2_2')
+
+  return net
+
+
+################################################################################
+# SHARED ENCODERS
+################################################################################
+def dann_mnist(images,
+               weight_decay=0.0,
+               prefix='model',
+               num_classes=10,
+               **kwargs):
+  """Creates a convolution MNIST model.
+
+  Note that this model implements the architecture for MNIST proposed in:
+   Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+   JMLR 2016
+
+  Args:
+    images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
+    weight_decay: the value for the weight decay coefficient.
+    prefix: name of the model to use when prefixing tags.
+    num_classes: the number of output classes to use.
+    **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+  Returns:
+    the output logits, a tensor of size [batch_size, num_classes].
+    a dictionary with key/values the layer names and tensors.
+  """
+  end_points = {}
+
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+      end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
+      end_points['pool1'] = slim.max_pool2d(
+          end_points['conv1'], [2, 2], 2, scope='pool1')
+      end_points['conv2'] = slim.conv2d(
+          end_points['pool1'], 48, [5, 5], scope='conv2')
+      end_points['pool2'] = slim.max_pool2d(
+          end_points['conv2'], [2, 2], 2, scope='pool2')
+      end_points['fc3'] = slim.fully_connected(
+          slim.flatten(end_points['pool2']), 100, scope='fc3')
+      end_points['fc4'] = slim.fully_connected(
+          slim.flatten(end_points['fc3']), 100, scope='fc4')
+
+  logits = slim.fully_connected(
+      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
+
+  return logits, end_points
+
+
+def dann_svhn(images,
+              weight_decay=0.0,
+              prefix='model',
+              num_classes=10,
+              **kwargs):
+  """Creates the convolutional SVHN model.
+
+  Note that this model implements the architecture for SVHN proposed in:
+   Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+   JMLR 2016
+
+  Args:
+    images: the SVHN digits, a tensor of size [batch_size, 32, 32, 3].
+    weight_decay: the value for the weight decay coefficient.
+    prefix: name of the model to use when prefixing tags.
+    num_classes: the number of output classes to use.
+    **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+  Returns:
+    the output logits, a tensor of size [batch_size, num_classes].
+    a dictionary with key/values the layer names and tensors.
+  """
+
+  end_points = {}
+
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+
+      end_points['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
+      end_points['pool1'] = slim.max_pool2d(
+          end_points['conv1'], [3, 3], 2, scope='pool1')
+      end_points['conv2'] = slim.conv2d(
+          end_points['pool1'], 64, [5, 5], scope='conv2')
+      end_points['pool2'] = slim.max_pool2d(
+          end_points['conv2'], [3, 3], 2, scope='pool2')
+      end_points['conv3'] = slim.conv2d(
+          end_points['pool2'], 128, [5, 5], scope='conv3')
+
+      end_points['fc3'] = slim.fully_connected(
+          slim.flatten(end_points['conv3']), 3072, scope='fc3')
+      end_points['fc4'] = slim.fully_connected(
+          slim.flatten(end_points['fc3']), 2048, scope='fc4')
+
+  logits = slim.fully_connected(
+      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
+
+  return logits, end_points
+
+
+def dann_gtsrb(images,
+               weight_decay=0.0,
+               prefix='model',
+               num_classes=43,
+               **kwargs):
+  """Creates the convolutional GTSRB model.
+
+  Note that this model implements the architecture for GTSRB proposed in:
+   Y. Ganin et al., Domain-Adversarial Training of Neural Networks (DANN),
+   JMLR 2016
+
+  Args:
+    images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
+    weight_decay: the value for the weight decay coefficient.
+    prefix: name of the model to use when prefixing tags.
+    num_classes: the number of output classes to use.
+    **kwargs: Placeholder for keyword arguments used by other shared encoders.
+
+  Returns:
+    the output logits, a tensor of size [batch_size, num_classes].
+    a dictionary with key/values the layer names and tensors.
+  """
+
+  end_points = {}
+
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+
+      end_points['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
+      end_points['pool1'] = slim.max_pool2d(
+          end_points['conv1'], [2, 2], 2, scope='pool1')
+      end_points['conv2'] = slim.conv2d(
+          end_points['pool1'], 144, [3, 3], scope='conv2')
+      end_points['pool2'] = slim.max_pool2d(
+          end_points['conv2'], [2, 2], 2, scope='pool2')
+      end_points['conv3'] = slim.conv2d(
+          end_points['pool2'], 256, [5, 5], scope='conv3')
+      end_points['pool3'] = slim.max_pool2d(
+          end_points['conv3'], [2, 2], 2, scope='pool3')
+
+      end_points['fc3'] = slim.fully_connected(
+          slim.flatten(end_points['pool3']), 512, scope='fc3')
+
+  logits = slim.fully_connected(
+      end_points['fc3'], num_classes, activation_fn=None, scope='fc4')
+
+  return logits, end_points
+
+
+def dsn_cropped_linemod(images,
+                        weight_decay=0.0,
+                        prefix='model',
+                        num_classes=11,
+                        batch_norm_params=None,
+                        is_training=False):
+  """Creates the convolutional pose estimation model for Cropped Linemod.
+
+  Args:
+    images: the Cropped Linemod samples, a tensor of size
+      [batch_size, 64, 64, 4].
+    weight_decay: the value for the weight decay coefficient.
+    prefix: name of the model to use when prefixing tags.
+    num_classes: the number of output classes to use.
+    batch_norm_params: a dictionary that maps batch norm parameter names to
+      values.
+    is_training: specifies whether or not we're currently training the model.
+      This variable will determine the behaviour of the dropout layer.
+
+  Returns:
+    the output logits, a tensor of size [batch_size, num_classes].
+    a dictionary with key/values the layer names and tensors.
+  """
+
+  end_points = {}
+
+  tf.summary.image('{}/input_images'.format(prefix), images)
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm if batch_norm_params else None,
+      normalizer_params=batch_norm_params):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+      end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
+      end_points['pool1'] = slim.max_pool2d(
+          end_points['conv1'], [2, 2], 2, scope='pool1')
+      end_points['conv2'] = slim.conv2d(
+          end_points['pool1'], 64, [5, 5], scope='conv2')
+      end_points['pool2'] = slim.max_pool2d(
+          end_points['conv2'], [2, 2], 2, scope='pool2')
+      net = slim.flatten(end_points['pool2'])
+      end_points['fc3'] = slim.fully_connected(net, 128, scope='fc3')
+      net = slim.dropout(
+          end_points['fc3'], 0.5, is_training=is_training, scope='dropout')
+
+      with tf.variable_scope('quaternion_prediction'):
+        predicted_quaternion = slim.fully_connected(
+            net, 4, activation_fn=tf.nn.tanh)
+        predicted_quaternion = tf.nn.l2_normalize(predicted_quaternion, 1)
+      logits = slim.fully_connected(
+          net, num_classes, activation_fn=None, scope='fc4')
+  end_points['quaternion_pred'] = predicted_quaternion
+
+  return logits, end_points

+ 167 - 0
domain_adaptation/domain_separation/models_test.py

@@ -0,0 +1,167 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DSN components."""
+
+import numpy as np
+import tensorflow as tf
+
+import models
+
+
+class SharedEncodersTest(tf.test.TestCase):
+
+  def _testSharedEncoder(self,
+                         input_shape=[5, 28, 28, 1],
+                         model=models.dann_mnist,
+                         is_training=True):
+    images = tf.to_float(np.random.rand(*input_shape))
+
+    with self.test_session() as sess:
+      logits, _ = model(images)
+      sess.run(tf.global_variables_initializer())
+      logits_np = sess.run(logits)
+    return logits_np
+
+  def testBuildGRLMnistModel(self):
+    logits = self._testSharedEncoder(model=getattr(models,
+                                                   'dann_mnist'))
+    self.assertEqual(logits.shape, (5, 10))
+    self.assertTrue(np.any(logits))
+
+  def testBuildGRLSvhnModel(self):
+    logits = self._testSharedEncoder(model=getattr(models,
+                                                   'dann_svhn'))
+    self.assertEqual(logits.shape, (5, 10))
+    self.assertTrue(np.any(logits))
+
+  def testBuildGRLGtsrbModel(self):
+    logits = self._testSharedEncoder([5, 40, 40, 3],
+                                     getattr(models, 'dann_gtsrb'))
+    self.assertEqual(logits.shape, (5, 43))
+    self.assertTrue(np.any(logits))
+
+  def testBuildPoseModel(self):
+    logits = self._testSharedEncoder([5, 64, 64, 4],
+                                     getattr(models, 'dsn_cropped_linemod'))
+    self.assertEqual(logits.shape, (5, 11))
+    self.assertTrue(np.any(logits))
+
+  def testBuildPoseModelWithBatchNorm(self):
+    images = tf.to_float(np.random.rand(10, 64, 64, 4))
+
+    with self.test_session() as sess:
+      logits, _ = getattr(models, 'dsn_cropped_linemod')(
+          images, batch_norm_params=models.default_batch_norm_params(True))
+      sess.run(tf.global_variables_initializer())
+      logits_np = sess.run(logits)
+    self.assertEqual(logits_np.shape, (10, 11))
+    self.assertTrue(np.any(logits_np))
+
+
+class EncoderTest(tf.test.TestCase):
+
+  def _testEncoder(self, batch_norm_params=None, channels=1):
+    images = tf.to_float(np.random.rand(10, 28, 28, channels))
+
+    with self.test_session() as sess:
+      end_points = models.default_encoder(
+          images, 128, batch_norm_params=batch_norm_params)
+      sess.run(tf.global_variables_initializer())
+      private_code = sess.run(end_points['fc3'])
+    self.assertEqual(private_code.shape, (10, 128))
+    self.assertTrue(np.any(private_code))
+    self.assertTrue(np.all(np.isfinite(private_code)))
+
+  def testEncoder(self):
+    self._testEncoder()
+
+  def testEncoderMultiChannel(self):
+    self._testEncoder(None, 4)
+
+  def testEncoderIsTrainingBatchNorm(self):
+    self._testEncoder(models.default_batch_norm_params(True))
+
+  def testEncoderBatchNorm(self):
+    self._testEncoder(models.default_batch_norm_params(False))
+
+
+class DecoderTest(tf.test.TestCase):
+
+  def _testDecoder(self,
+                   height=64,
+                   width=64,
+                   channels=4,
+                   batch_norm_params=None,
+                   decoder=models.small_decoder):
+    codes = tf.to_float(np.random.rand(32, 100))
+
+    with self.test_session() as sess:
+      output = decoder(
+          codes,
+          height=height,
+          width=width,
+          channels=channels,
+          batch_norm_params=batch_norm_params)
+      sess.run(tf.global_variables_initializer())
+      output_np = sess.run(output)
+    self.assertEqual(output_np.shape, (32, height, width, channels))
+    self.assertTrue(np.any(output_np))
+    self.assertTrue(np.all(np.isfinite(output_np)))
+
+  def testSmallDecoder(self):
+    self._testDecoder(28, 28, 4, None, getattr(models, 'small_decoder'))
+
+  def testSmallDecoderThreeChannels(self):
+    self._testDecoder(28, 28, 3)
+
+  def testSmallDecoderBatchNorm(self):
+    self._testDecoder(28, 28, 4, models.default_batch_norm_params(False))
+
+  def testSmallDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(28, 28, 4, models.default_batch_norm_params(True))
+
+  def testLargeDecoder(self):
+    self._testDecoder(32, 32, 4, None, getattr(models, 'large_decoder'))
+
+  def testLargeDecoderThreeChannels(self):
+    self._testDecoder(32, 32, 3, None, getattr(models, 'large_decoder'))
+
+  def testLargeDecoderBatchNorm(self):
+    self._testDecoder(32, 32, 4,
+                      models.default_batch_norm_params(False),
+                      getattr(models, 'large_decoder'))
+
+  def testLargeDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(32, 32, 4,
+                      models.default_batch_norm_params(True),
+                      getattr(models, 'large_decoder'))
+
+  def testGtsrbDecoder(self):
+    self._testDecoder(40, 40, 3, None, getattr(models, 'gtsrb_decoder'))
+
+  def testGtsrbDecoderBatchNorm(self):
+    self._testDecoder(40, 40, 4,
+                      models.default_batch_norm_params(False),
+                      getattr(models, 'gtsrb_decoder'))
+
+  def testGtsrbDecoderIsTrainingBatchNorm(self):
+    self._testDecoder(40, 40, 4,
+                      models.default_batch_norm_params(True),
+                      getattr(models, 'gtsrb_decoder'))
+
+
+if __name__ == '__main__':
+  tf.test.main()

+ 183 - 0
domain_adaptation/domain_separation/utils.py

@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Auxiliary functions for domain adaptation related losses.
+"""
+import math
+import tensorflow as tf
+
+
+def create_summaries(end_points, prefix='', max_images=3, use_op_name=False):
+  """Creates a tf summary per endpoint.
+
+  A 4-dimensional tensor is displayed as an image, a 3-dimensional tensor is
+  expanded and displayed as an image, and a 2-dimensional tensor produces a
+  histogram summary.
+
+  Args:
+    end_points: a dictionary of name, tf tensor pairs.
+    prefix: an optional string to prefix the summary with.
+    max_images: the maximum number of images to display per summary.
+    use_op_name: Use the op name as opposed to the shorter end_points key.
+  """
+  for layer_name in end_points:
+    if use_op_name:
+      name = end_points[layer_name].op.name
+    else:
+      name = layer_name
+    if len(end_points[layer_name].get_shape().as_list()) == 4:
+      # If it is an actual image, do not attempt to reshape it.
+      num_channels = end_points[layer_name].get_shape().as_list()[-1]
+      if num_channels in (1, 3):
+        visualization_image = end_points[layer_name]
+      else:
+        visualization_image = reshape_feature_maps(end_points[layer_name])
+      tf.summary.image(
+          '{}/{}'.format(prefix, name),
+          visualization_image,
+          max_outputs=max_images)
+    elif len(end_points[layer_name].get_shape().as_list()) == 3:
+      images = tf.expand_dims(end_points[layer_name], 3)
+      tf.summary.image(
+          '{}/{}'.format(prefix, name),
+          images,
+          max_outputs=max_images)
+    elif len(end_points[layer_name].get_shape().as_list()) == 2:
+      tf.summary.histogram('{}/{}'.format(prefix, name), end_points[layer_name])
+
+
+def reshape_feature_maps(features_tensor):
+  """Reshape activations for tf.summary.image visualization.
+
+  Args:
+    features_tensor: a tensor of activations with a square number of feature
+                     maps, eg 4, 9, 16, etc.
+  Returns:
+    A composite image with all the feature maps that can be passed as an
+    argument to tf.summary.image.
+  """
+  assert len(features_tensor.get_shape().as_list()) == 4
+  num_filters = features_tensor.get_shape().as_list()[-1]
+  assert num_filters > 0
+  num_filters_sqrt = math.sqrt(num_filters)
+  assert num_filters_sqrt.is_integer(
+  ), 'Number of filters should be a square number but got {}'.format(
+      num_filters)
+  num_filters_sqrt = int(num_filters_sqrt)
+  conv_summary = tf.unstack(features_tensor, axis=3)
+  conv_one_row = tf.concat(conv_summary[0:num_filters_sqrt], 2)
+  conv_final = conv_one_row
+  for ind in range(1, num_filters_sqrt):
+    conv_one_row = tf.concat(conv_summary[
+        ind * num_filters_sqrt + 0:ind * num_filters_sqrt + num_filters_sqrt],
+                             2)
+    conv_final = tf.concat(
+        [tf.squeeze(conv_final), tf.squeeze(conv_one_row)], 1)
+    conv_final = tf.expand_dims(conv_final, -1)
+  return conv_final
+
+
+def accuracy(predictions, labels):
+  """Calculates the classificaton accuracy.
+
+  Args:
+    predictions: the predicted values, a tensor whose size matches 'labels'.
+    labels: the ground truth values, a tensor of any size.
+
+  Returns:
+    a tensor whose value on evaluation returns the total accuracy.
+  """
+  return tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
+
+
+def compute_upsample_values(input_tensor, upsample_height, upsample_width):
+  """Compute values for an upsampling op (ops.BatchCropAndResize).
+
+  Args:
+    input_tensor: image tensor with shape [batch, height, width, in_channels]
+    upsample_height: integer
+    upsample_width: integer
+
+  Returns:
+    grid_centers: tensor with shape [batch, 2]
+    crop_sizes: tensor with shape [batch, 2]
+    output_height: integer
+    output_width: integer
+  """
+  batch, input_height, input_width, _ = input_tensor.shape
+
+  height_half = input_height / 2.
+  width_half = input_width / 2.
+  grid_centers = tf.constant(batch * [[height_half, width_half]])
+  crop_sizes = tf.constant(batch * [[input_height, input_width]])
+  output_height = input_height * upsample_height
+  output_width = input_width * upsample_width
+
+  return grid_centers, tf.to_float(crop_sizes), output_height, output_width
+
+
+def compute_pairwise_distances(x, y):
+  """Computes the squared pairwise Euclidean distances between x and y.
+
+  Args:
+    x: a tensor of shape [num_x_samples, num_features]
+    y: a tensor of shape [num_y_samples, num_features]
+
+  Returns:
+    a distance matrix of dimensions [num_x_samples, num_y_samples].
+
+  Raises:
+    ValueError: if the inputs do not match the specified dimensions.
+  """
+
+  if not len(x.get_shape()) == len(y.get_shape()) == 2:
+    raise ValueError('Both inputs should be matrices.')
+
+  if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
+    raise ValueError('The number of features should be the same.')
+
+  norm = lambda x: tf.reduce_sum(tf.square(x), 1)
+
+  # By making the `inner' dimensions of the two matrices equal to 1 using
+  # broadcasting, we are essentially subtracting every pair of rows
+  # of x and y.
+  # x will be num_samples x num_features x 1,
+  # and y will be 1 x num_features x num_samples (after broadcasting).
+  # After the subtraction we will get a
+  # num_x_samples x num_features x num_y_samples matrix.
+  # The resulting dist will be of shape num_y_samples x num_x_samples,
+  # and thus we need to transpose it again.
+  return tf.transpose(norm(tf.expand_dims(x, 2) - tf.transpose(y)))
+
+
+def gaussian_kernel_matrix(x, y, sigmas):
+  r"""Computes a Guassian Radial Basis Kernel between the samples of x and y.
+
+  We create a sum of multiple gaussian kernels each having a width sigma_i.
+
+  Args:
+    x: a tensor of shape [num_samples, num_features]
+    y: a tensor of shape [num_samples, num_features]
+    sigmas: a tensor of floats which denote the widths of each of the
+      gaussians in the kernel.
+  Returns:
+    A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
+  """
+  beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))
+
+  dist = compute_pairwise_distances(x, y)
+
+  s = tf.matmul(beta, tf.reshape(dist, (1, -1)))
+
+  return tf.reshape(tf.reduce_sum(tf.exp(-s), 0), tf.shape(dist))
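
A small sanity-check sketch for the kernel above (illustrative, not part of the commit): the multi-sigma Gaussian kernel matrix should match a direct NumPy computation of sum_i exp(-||x_a - x_b||^2 / (2 * sigma_i)).

    import numpy as np
    import tensorflow as tf

    import utils

    x = np.random.rand(4, 3).astype(np.float32)
    sigmas = [1.0, 5.0]
    kernel = utils.gaussian_kernel_matrix(
        tf.constant(x), tf.constant(x), tf.constant(sigmas))

    # Direct computation: pairwise squared distances, then a sum of RBFs.
    dist = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    expected = sum(np.exp(-dist / (2.0 * s)) for s in sigmas)
    with tf.Session() as sess:
      print(np.allclose(sess.run(kernel), expected, atol=1e-5))  # True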