# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """ResNet model.
  16. Related papers:
  17. https://arxiv.org/pdf/1603.05027v2.pdf
  18. https://arxiv.org/pdf/1512.03385v1.pdf
  19. https://arxiv.org/pdf/1605.07146v1.pdf
  20. """
  21. from collections import namedtuple
  22. import numpy as np
  23. import tensorflow as tf
  24. import six
  25. from tensorflow.python.training import moving_averages

HParams = namedtuple('HParams',
                     'batch_size, num_classes, min_lrn_rate, lrn_rate, '
                     'num_residual_units, use_bottleneck, weight_decay_rate, '
                     'relu_leakiness, optimizer')
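
# Descriptive notes (added for clarity; not in the original file):
# 'use_bottleneck' selects the 3-conv bottleneck unit over the plain 2-conv
# unit, 'relu_leakiness' is the negative slope of the leaky ReLU, and
# 'optimizer' is one of 'sgd' or 'mom' (momentum 0.9).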


class ResNet(object):
  """ResNet model."""

  def __init__(self, hps, images, labels, mode):
    """ResNet constructor.

    Args:
      hps: Hyperparameters.
      images: Batches of images. [batch_size, image_size, image_size, 3]
      labels: Batches of labels. [batch_size, num_classes]
      mode: One of 'train' and 'eval'.
    """
    self.hps = hps
    self._images = images
    self.labels = labels
    self.mode = mode

    self._extra_train_ops = []

  def build_graph(self):
    """Build a whole graph for the model."""
    self.global_step = tf.contrib.framework.get_or_create_global_step()
    self._build_model()
    if self.mode == 'train':
      self._build_train_op()
    self.summaries = tf.summary.merge_all()

  def _stride_arr(self, stride):
    """Map a stride scalar to the stride array for tf.nn.conv2d."""
    return [1, stride, stride, 1]

  def _build_model(self):
    """Build the core model within the graph."""
    with tf.variable_scope('init'):
      x = self._images
      x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))

    strides = [1, 2, 2]
    activate_before_residual = [True, False, False]
    if self.hps.use_bottleneck:
      res_func = self._bottleneck_residual
      filters = [16, 64, 128, 256]
    else:
      res_func = self._residual
      filters = [16, 16, 32, 64]
      # Uncomment the following code to use the w28-10 wide residual network.
      # It is more memory efficient than a very deep residual network and has
      # comparably good performance.
      # https://arxiv.org/pdf/1605.07146v1.pdf
      # filters = [16, 160, 320, 640]
      # Update hps.num_residual_units to 4.

    with tf.variable_scope('unit_1_0'):
      x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                   activate_before_residual[0])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_1_%d' % i):
        x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)

    with tf.variable_scope('unit_2_0'):
      x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                   activate_before_residual[1])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_2_%d' % i):
        x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)

    with tf.variable_scope('unit_3_0'):
      x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                   activate_before_residual[2])
    for i in six.moves.range(1, self.hps.num_residual_units):
      with tf.variable_scope('unit_3_%d' % i):
        x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)

    with tf.variable_scope('unit_last'):
      x = self._batch_norm('final_bn', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._global_avg_pool(x)

    with tf.variable_scope('logit'):
      logits = self._fully_connected(x, self.hps.num_classes)
      self.predictions = tf.nn.softmax(logits)

    with tf.variable_scope('costs'):
      xent = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=self.labels)
      self.cost = tf.reduce_mean(xent, name='xent')
      self.cost += self._decay()

      tf.summary.scalar('cost', self.cost)

  def _build_train_op(self):
    """Build training specific ops for the graph."""
    self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
    tf.summary.scalar('learning_rate', self.lrn_rate)

    trainable_variables = tf.trainable_variables()
    grads = tf.gradients(self.cost, trainable_variables)

    if self.hps.optimizer == 'sgd':
      optimizer = tf.train.GradientDescentOptimizer(self.lrn_rate)
    elif self.hps.optimizer == 'mom':
      optimizer = tf.train.MomentumOptimizer(self.lrn_rate, 0.9)

    apply_op = optimizer.apply_gradients(
        zip(grads, trainable_variables),
        global_step=self.global_step, name='train_step')

    train_ops = [apply_op] + self._extra_train_ops
    self.train_op = tf.group(*train_ops)

  # TODO(xpan): Consider batch_norm in contrib/layers/python/layers/layers.py
  def _batch_norm(self, name, x):
    """Batch normalization."""
    with tf.variable_scope(name):
      params_shape = [x.get_shape()[-1]]

      beta = tf.get_variable(
          'beta', params_shape, tf.float32,
          initializer=tf.constant_initializer(0.0, tf.float32))
      gamma = tf.get_variable(
          'gamma', params_shape, tf.float32,
          initializer=tf.constant_initializer(1.0, tf.float32))

      if self.mode == 'train':
        # Use batch statistics and update the moving averages that eval reads.
        mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

        moving_mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        moving_variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)

        self._extra_train_ops.append(moving_averages.assign_moving_average(
            moving_mean, mean, 0.9))
        self._extra_train_ops.append(moving_averages.assign_moving_average(
            moving_variance, variance, 0.9))
      else:
        # In eval mode, normalize with the moving averages saved at training.
        mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)
        tf.summary.histogram(mean.op.name, mean)
        tf.summary.histogram(variance.op.name, variance)
      # epsilon used to be 1e-5. Maybe 0.001 solves the NaN problem in deeper
      # nets.
      y = tf.nn.batch_normalization(
          x, mean, variance, beta, gamma, 0.001)
      y.set_shape(x.get_shape())
      return y

  def _residual(self, x, in_filter, out_filter, stride,
                activate_before_residual=False):
    """Residual unit with 2 sub layers."""
    if activate_before_residual:
      with tf.variable_scope('shared_activation'):
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)
        orig_x = x
    else:
      with tf.variable_scope('residual_only_activation'):
        orig_x = x
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    with tf.variable_scope('sub1'):
      x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

    with tf.variable_scope('sub2'):
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

    with tf.variable_scope('sub_add'):
      if in_filter != out_filter:
        orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
        orig_x = tf.pad(
            orig_x, [[0, 0], [0, 0], [0, 0],
                     [(out_filter-in_filter)//2, (out_filter-in_filter)//2]])
      x += orig_x

    tf.logging.debug('image after unit %s', x.get_shape())
    return x

  def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                           activate_before_residual=False):
    """Bottleneck residual unit with 3 sub layers."""
    if activate_before_residual:
      with tf.variable_scope('common_bn_relu'):
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)
        orig_x = x
    else:
      with tf.variable_scope('residual_bn_relu'):
        orig_x = x
        x = self._batch_norm('init_bn', x)
        x = self._relu(x, self.hps.relu_leakiness)

    with tf.variable_scope('sub1'):
      # Integer division keeps the filter count integral under Python 3.
      x = self._conv('conv1', x, 1, in_filter, out_filter // 4, stride)

    with tf.variable_scope('sub2'):
      x = self._batch_norm('bn2', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv2', x, 3, out_filter // 4, out_filter // 4,
                     [1, 1, 1, 1])

    with tf.variable_scope('sub3'):
      x = self._batch_norm('bn3', x)
      x = self._relu(x, self.hps.relu_leakiness)
      x = self._conv('conv3', x, 1, out_filter // 4, out_filter,
                     [1, 1, 1, 1])

    with tf.variable_scope('sub_add'):
      if in_filter != out_filter:
        orig_x = self._conv('project', orig_x, 1, in_filter, out_filter,
                            stride)
      x += orig_x

    tf.logging.info('image after unit %s', x.get_shape())
    return x

  def _decay(self):
    """L2 weight decay loss."""
    costs = []
    for var in tf.trainable_variables():
      if var.op.name.find(r'DW') > 0:
        costs.append(tf.nn.l2_loss(var))
        # tf.summary.histogram(var.op.name, var)

    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

  def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
    """Convolution."""
    with tf.variable_scope(name):
      n = filter_size * filter_size * out_filters
      kernel = tf.get_variable(
          'DW', [filter_size, filter_size, in_filters, out_filters],
          tf.float32, initializer=tf.random_normal_initializer(
              stddev=np.sqrt(2.0 / n)))
      return tf.nn.conv2d(x, kernel, strides, padding='SAME')

  def _relu(self, x, leakiness=0.0):
    """Relu, with optional leaky support."""
    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

  def _fully_connected(self, x, out_dim):
    """FullyConnected layer for final output."""
    x = tf.reshape(x, [self.hps.batch_size, -1])
    w = tf.get_variable(
        'DW', [x.get_shape()[1], out_dim],
        initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    b = tf.get_variable('biases', [out_dim],
                        initializer=tf.constant_initializer())
    return tf.nn.xw_plus_b(x, w, b)

  def _global_avg_pool(self, x):
    """Global average pooling over the spatial dimensions."""
    assert x.get_shape().ndims == 4
    return tf.reduce_mean(x, [1, 2])
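

# A minimal usage sketch, added for illustration (not part of the original
# file). The hyperparameter values and CIFAR-10-sized input shapes below are
# assumptions; a real pipeline feeds actual image/label batches.
if __name__ == '__main__':
  hps = HParams(batch_size=128, num_classes=10, min_lrn_rate=0.0001,
                lrn_rate=0.1, num_residual_units=5, use_bottleneck=False,
                weight_decay_rate=0.0002, relu_leakiness=0.1,
                optimizer='mom')
  images = tf.placeholder(tf.float32, [hps.batch_size, 32, 32, 3])
  labels = tf.placeholder(tf.float32, [hps.batch_size, hps.num_classes])
  model = ResNet(hps, images, labels, 'train')
  model.build_graph()
  # model.train_op, model.cost, and model.summaries can now be run inside a
  # tf.Session, feeding batches to the two placeholders.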