models.py

# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Contains different architectures for the different DSN parts.
  16. We define here the modules that can be used in the different parts of the DSN
  17. model.
  18. - shared encoder (dsn_cropped_linemod, dann_xxxx)
  19. - private encoder (default_encoder)
  20. - decoder (large_decoder, gtsrb_decoder, small_decoder)
  21. """
  22. import tensorflow as tf
  23. #from models.domain_adaptation.domain_separation
  24. import utils
  25. slim = tf.contrib.slim


def default_batch_norm_params(is_training=False):
  """Returns default batch normalization parameters for DSNs.

  Args:
    is_training: whether or not the model is training.

  Returns:
    a dictionary that maps batch norm parameter names (strings) to values.
  """
  return {
      # Decay for the moving averages.
      'decay': 0.5,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
      'is_training': is_training
  }
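
# Illustrative note (editor's sketch, not part of the original module): the
# dictionary above is meant to be passed as `normalizer_params` wherever
# slim.batch_norm is used as a normalizer, or via the `batch_norm_params`
# argument of the encoders/decoders below, e.g.:
#
#   bn_params = default_batch_norm_params(is_training=True)
#   end_points = default_encoder(images, code_size=100,
#                                batch_norm_params=bn_params)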


################################################################################
# PRIVATE ENCODERS
################################################################################
def default_encoder(images, code_size, batch_norm_params=None,
                    weight_decay=0.0):
  """Encodes the given images to codes of the given size.

  Args:
    images: a tensor of size [batch_size, height, width, 1].
    code_size: the number of hidden units in the code layer of the classifier.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    end_points: a dictionary of activations, with the code of the input
      stored under the key 'fc3'.
  """
  end_points = {}
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], padding='SAME'):
      net = slim.conv2d(images, 32, scope='conv1')
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
      net = slim.conv2d(net, 64, scope='conv2')
      net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
      net = slim.flatten(net)
      end_points['flatten'] = net
      net = slim.fully_connected(net, code_size, scope='fc1')
      end_points['fc3'] = net
  return end_points
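
# Illustrative usage (editor's sketch): the private encoder maps images
# straight to a code vector. Note that the final layer is scoped 'fc1' but
# stored under the key 'fc3' in end_points.
#
#   end_points = default_encoder(mnist_images, code_size=100,
#                                batch_norm_params=default_batch_norm_params())
#   private_code = end_points['fc3']  # shape [batch_size, 100]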


################################################################################
# DECODERS
################################################################################
def large_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size.

  Args:
    codes: a tensor of size [batch_size, code_size].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape
      [batch_size, height, width, channels].
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    net = slim.fully_connected(codes, 600, scope='fc1')
    batch_size = net.get_shape().as_list()[0]
    net = tf.reshape(net, [batch_size, 10, 10, 6])
    net = slim.conv2d(net, 32, [5, 5], scope='conv1_1')
    net = tf.image.resize_nearest_neighbor(net, (16, 16))
    net = slim.conv2d(net, 32, [5, 5], scope='conv2_1')
    net = tf.image.resize_nearest_neighbor(net, (32, 32))
    net = slim.conv2d(net, 32, [5, 5], scope='conv3_2')
    output_size = [height, width]
    net = tf.image.resize_nearest_neighbor(net, output_size)
    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv4_1')
  return net
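
# Illustrative sketch (editor's addition): fc1 outputs 600 units precisely so
# the code can be reshaped to a [batch_size, 10, 10, 6] map, which is then
# upsampled 10 -> 16 -> 32 -> (height, width) by nearest-neighbor resizing.
# In a DSN the decoder input would typically be the combined shared and
# private codes (an assumed usage), e.g.:
#
#   recons = large_decoder(shared_code + private_code,
#                          height=32, width=32, channels=3)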


def gtsrb_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size. This decoder is specific to GTSRB.

  Args:
    codes: a tensor of size [batch_size, 100].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape
      [batch_size, height, width, channels].

  Raises:
    ValueError: when the input code size is not 100.
  """
  batch_size, code_size = codes.get_shape().as_list()
  if code_size != 100:
    raise ValueError('The code size used as an input to the GTSRB decoder is '
                     'expected to be 100.')
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    net = codes
    net = tf.reshape(net, [batch_size, 10, 10, 1])
    net = slim.conv2d(net, 32, [3, 3], scope='conv1_1')
    # First upsampling to 20x20.
    net = tf.image.resize_nearest_neighbor(net, [20, 20])
    net = slim.conv2d(net, 32, [3, 3], scope='conv2_1')
    output_size = [height, width]
    # Final upsampling to `output_size` (40x40 for GTSRB).
    net = tf.image.resize_nearest_neighbor(net, output_size)
    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      net = slim.conv2d(net, 16, scope='conv3_1')
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv3_2')
  return net
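
# Note (editor's sketch): the code size must be exactly 100 because the code
# is reshaped directly into a [batch_size, 10, 10, 1] map with no projection
# layer in front of it, e.g.:
#
#   recons = gtsrb_decoder(codes_100d, height=40, width=40, channels=3)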


def small_decoder(codes,
                  height,
                  width,
                  channels,
                  batch_norm_params=None,
                  weight_decay=0.0):
  """Decodes the codes to a fixed output size.

  Args:
    codes: a tensor of size [batch_size, code_size].
    height: the height of the output images.
    width: the width of the output images.
    channels: the number of the output channels.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    weight_decay: the value for the weight decay coefficient.

  Returns:
    recons: the reconstruction tensor of shape
      [batch_size, height, width, channels].
  """
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      normalizer_params=batch_norm_params):
    net = slim.fully_connected(codes, 300, scope='fc1')
    batch_size = net.get_shape().as_list()[0]
    net = tf.reshape(net, [batch_size, 10, 10, 3])
    net = slim.conv2d(net, 16, [3, 3], scope='conv1_1')
    net = slim.conv2d(net, 16, [3, 3], scope='conv1_2')
    output_size = [height, width]
    net = tf.image.resize_nearest_neighbor(net, output_size)
    with slim.arg_scope([slim.conv2d], kernel_size=[3, 3]):
      net = slim.conv2d(net, 16, scope='conv2_1')
      net = slim.conv2d(net, channels, activation_fn=None, scope='conv2_2')
  return net


################################################################################
# SHARED ENCODERS
################################################################################
def dann_mnist(images,
               weight_decay=0.0,
               prefix='model',
               num_classes=10,
               **kwargs):
  """Creates a convolutional MNIST model.

  Note that this model implements the architecture for MNIST proposed in:
    Y. Ganin et al., "Domain-Adversarial Training of Neural Networks" (DANN),
    JMLR 2016.

  Args:
    images: the MNIST digits, a tensor of size [batch_size, 28, 28, 1].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [2, 2], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 48, [5, 5], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [2, 2], 2, scope='pool2')
      end_points['fc3'] = slim.fully_connected(
          slim.flatten(end_points['pool2']), 100, scope='fc3')
      end_points['fc4'] = slim.fully_connected(
          slim.flatten(end_points['fc3']), 100, scope='fc4')

  logits = slim.fully_connected(
      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
  return logits, end_points
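
# Illustrative usage (editor's sketch): as a DSN shared encoder,
#
#   logits, end_points = dann_mnist(mnist_batch, weight_decay=1e-4)
#   shared_features = end_points['fc4']  # 100-d penultimate activations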


def dann_svhn(images,
              weight_decay=0.0,
              prefix='model',
              num_classes=10,
              **kwargs):
  """Creates the convolutional SVHN model.

  Note that this model implements the architecture for SVHN proposed in:
    Y. Ganin et al., "Domain-Adversarial Training of Neural Networks" (DANN),
    JMLR 2016.

  Args:
    images: the SVHN digits, a tensor of size [batch_size, 32, 32, 3].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      end_points['conv1'] = slim.conv2d(images, 64, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [3, 3], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 64, [5, 5], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [3, 3], 2, scope='pool2')
      end_points['conv3'] = slim.conv2d(
          end_points['pool2'], 128, [5, 5], scope='conv3')
      end_points['fc3'] = slim.fully_connected(
          slim.flatten(end_points['conv3']), 3072, scope='fc3')
      end_points['fc4'] = slim.fully_connected(
          slim.flatten(end_points['fc3']), 2048, scope='fc4')

  logits = slim.fully_connected(
      end_points['fc4'], num_classes, activation_fn=None, scope='fc5')
  return logits, end_points


def dann_gtsrb(images,
               weight_decay=0.0,
               prefix='model',
               num_classes=43,
               **kwargs):
  """Creates the convolutional GTSRB model.

  Note that this model implements the architecture for GTSRB proposed in:
    Y. Ganin et al., "Domain-Adversarial Training of Neural Networks" (DANN),
    JMLR 2016.

  Args:
    images: the GTSRB images, a tensor of size [batch_size, 40, 40, 3].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    **kwargs: Placeholder for keyword arguments used by other shared encoders.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      end_points['conv1'] = slim.conv2d(images, 96, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [2, 2], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 144, [3, 3], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [2, 2], 2, scope='pool2')
      end_points['conv3'] = slim.conv2d(
          end_points['pool2'], 256, [5, 5], scope='conv3')
      end_points['pool3'] = slim.max_pool2d(
          end_points['conv3'], [2, 2], 2, scope='pool3')
      end_points['fc3'] = slim.fully_connected(
          slim.flatten(end_points['pool3']), 512, scope='fc3')

  logits = slim.fully_connected(
      end_points['fc3'], num_classes, activation_fn=None, scope='fc4')
  return logits, end_points


def dsn_cropped_linemod(images,
                        weight_decay=0.0,
                        prefix='model',
                        num_classes=11,
                        batch_norm_params=None,
                        is_training=False):
  """Creates the convolutional pose estimation model for Cropped Linemod.

  Args:
    images: the Cropped Linemod samples, a tensor of size
      [batch_size, 64, 64, 4].
    weight_decay: the value for the weight decay coefficient.
    prefix: name of the model to use when prefixing tags.
    num_classes: the number of output classes to use.
    batch_norm_params: a dictionary that maps batch norm parameter names to
      values.
    is_training: specifies whether or not we're currently training the model.
      This variable will determine the behaviour of the dropout layer.

  Returns:
    the output logits, a tensor of size [batch_size, num_classes].
    a dictionary with key/values the layer names and tensors.
  """
  end_points = {}
  tf.summary.image('{}/input_images'.format(prefix), images)
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm if batch_norm_params else None,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1')
      end_points['pool1'] = slim.max_pool2d(
          end_points['conv1'], [2, 2], 2, scope='pool1')
      end_points['conv2'] = slim.conv2d(
          end_points['pool1'], 64, [5, 5], scope='conv2')
      end_points['pool2'] = slim.max_pool2d(
          end_points['conv2'], [2, 2], 2, scope='pool2')
      net = slim.flatten(end_points['pool2'])
      end_points['fc3'] = slim.fully_connected(net, 128, scope='fc3')
      net = slim.dropout(
          end_points['fc3'], 0.5, is_training=is_training, scope='dropout')

      with tf.variable_scope('quaternion_prediction'):
        predicted_quaternion = slim.fully_connected(
            net, 4, activation_fn=tf.nn.tanh)
        predicted_quaternion = tf.nn.l2_normalize(predicted_quaternion, 1)

      logits = slim.fully_connected(
          net, num_classes, activation_fn=None, scope='fc4')

  end_points['quaternion_pred'] = predicted_quaternion
  return logits, end_points
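

if __name__ == '__main__':
  # Minimal smoke test (editor's illustrative addition, not part of the
  # original module): build a few of the graphs above with random inputs and
  # print the resulting static shapes. Assumes a TF 1.x environment where
  # tf.contrib.slim is available.
  with tf.variable_scope('mnist'):
    logits, _ = dann_mnist(tf.random_uniform([8, 28, 28, 1]))
    print('dann_mnist logits: %s' % logits.get_shape().as_list())  # [8, 10]
  with tf.variable_scope('private'):
    codes = default_encoder(tf.random_uniform([8, 28, 28, 1]), 100)['fc3']
    print('default_encoder code: %s' % codes.get_shape().as_list())  # [8, 100]
  with tf.variable_scope('decode'):
    recons = small_decoder(codes, height=28, width=28, channels=1)
    print('small_decoder output: %s' % recons.get_shape().as_list())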