inception_model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Build the Inception v3 network on ImageNet data set.
  16. The Inception v3 architecture is described in http://arxiv.org/abs/1512.00567
  17. Summary of available functions:
  18. inference: Compute inference on the model inputs to make a prediction
  19. loss: Compute the loss of the prediction with respect to the labels
  20. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf

from inception.slim import slim

FLAGS = tf.app.flags.FLAGS

# If a model is trained using multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

# Batch normalization. Constant governing the exponential moving average of
# the 'global' mean and variance for all activations.
BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

# The decay to use for the moving average of model weights.
MOVING_AVERAGE_DECAY = 0.9999
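

# Illustrative sketch, not part of the original file: MOVING_AVERAGE_DECAY is
# consumed by the companion training script, which shadows the trainable
# variables with an exponential moving average. The helper below is a
# hypothetical example of that wiring, assuming the caller supplies the
# training global_step.
def _moving_average_sketch_example(global_step):
  # With num_updates supplied, the effective decay ramps up toward
  # MOVING_AVERAGE_DECAY as global_step grows.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  # Returns the op that updates the shadow copies of the model weights.
  return variable_averages.apply(tf.trainable_variables())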


def inference(images, num_classes, for_training=False, restore_logits=True,
              scope=None):
  """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes.
    for_training: If set to `True`, build the inference model for training.
      Ops that behave differently between training and inference,
      e.g. dropout, are configured appropriately.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with a different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits. 2-D float Tensor.
    Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
  """
  # Parameters for BatchNorm.
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
  }
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
    with slim.arg_scope([slim.ops.conv2d],
                        stddev=0.1,
                        activation=tf.nn.relu,
                        batch_norm_params=batch_norm_params):
      logits, endpoints = slim.inception.inception_v3(
          images,
          dropout_keep_prob=0.8,
          num_classes=num_classes,
          is_training=for_training,
          restore_logits=restore_logits,
          scope=scope)

  # Add summaries for viewing model statistics on TensorBoard.
  _activation_summaries(endpoints)

  # Grab the logits associated with the side head. Employed during training.
  auxiliary_logits = endpoints['aux_logits']

  return logits, auxiliary_logits
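

# A minimal usage sketch, not part of the original file: how a training script
# typically calls inference() inside a named tower scope. This helper and its
# `images`/`num_classes` arguments are hypothetical placeholders for the real
# input pipeline.
def _inference_usage_example(images, num_classes):
  # Name the tower so activation summaries can strip the 'tower_N/' prefix.
  with tf.name_scope('%s_%d' % (TOWER_NAME, 0)) as scope:
    logits, aux_logits = inference(images, num_classes,
                                   for_training=True, scope=scope)
  return logits, aux_logits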


def loss(logits, labels, batch_size=None):
  """Adds all losses for the model.

  Note that the final loss is not returned. Instead, the individual losses are
  collected by slim.losses; they are accumulated in tower_loss() and summed to
  calculate the total loss.

  Args:
    logits: List of logits from inference(). Each entry is a 2-D float Tensor.
    labels: Labels from distorted_inputs() or inputs(). 1-D tensor
      of shape [batch_size].
    batch_size: integer, number of examples in the batch.
  """
  if not batch_size:
    batch_size = FLAGS.batch_size

  # Reshape the labels into a dense Tensor of
  # shape [batch_size, num_classes].
  sparse_labels = tf.reshape(labels, [batch_size, 1])
  indices = tf.reshape(tf.range(batch_size), [batch_size, 1])
  concated = tf.concat(1, [indices, sparse_labels])
  num_classes = logits[0].get_shape()[-1].value
  dense_labels = tf.sparse_to_dense(concated,
                                    [batch_size, num_classes],
                                    1.0, 0.0)

  # Cross entropy loss for the main softmax prediction.
  slim.losses.cross_entropy_loss(logits[0],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=1.0)

  # Cross entropy loss for the auxiliary softmax head.
  slim.losses.cross_entropy_loss(logits[1],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=0.4,
                                 scope='aux_loss')
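

# Illustrative only, not part of the original file: a plain-NumPy sketch of
# what the dense-label construction above computes, including the label
# smoothing that slim.losses.cross_entropy_loss applies internally (the
# smoothing formula below is the standard one and is an assumption about
# slim's internals, not code from this repository).
def _dense_labels_sketch(labels, num_classes, label_smoothing=0.1):
  import numpy as np
  batch_size = len(labels)
  # One-hot encoding: the NumPy analogue of the sparse_to_dense call above.
  dense = np.zeros((batch_size, num_classes), dtype=np.float32)
  dense[np.arange(batch_size), labels] = 1.0
  # Label smoothing: take `label_smoothing` probability mass from the true
  # class and spread it uniformly over all classes.
  return dense * (1.0 - label_smoothing) + label_smoothing / num_classes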


def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on TensorBoard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
  tf.histogram_summary(tensor_name + '/activations', x)
  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
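
# For example (illustrative, not in the original file): with TOWER_NAME set to
# 'tower', an op named 'tower_0/conv0/Relu' is summarized as 'conv0/Relu', so
# the per-GPU replicas of the model share one set of TensorBoard charts:
#
#   re.sub('%s_[0-9]*/' % TOWER_NAME, '', 'tower_0/conv0/Relu')
#   # -> 'conv0/Relu'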


def _activation_summaries(endpoints):
  """Creates an activation summary for each endpoint of the network."""
  with tf.name_scope('summaries'):
    for act in endpoints.values():
      _activation_summary(act)
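

if __name__ == '__main__':
  # A minimal, hedged smoke-test sketch, not part of the original file: wire
  # inference() and loss() together on placeholder inputs and sum the
  # collected losses the way the loss() docstring describes tower_loss()
  # doing. `slim.losses.LOSSES_COLLECTION` is an assumption about the slim
  # library bundled with this model.
  images = tf.placeholder(tf.float32, [32, 299, 299, 3])
  labels = tf.placeholder(tf.int32, [32])
  # 1001 classes = 1000 ImageNet classes plus a background class.
  logits = inference(images, num_classes=1001, for_training=True)
  loss(logits, labels, batch_size=32)
  total_loss = tf.add_n(tf.get_collection(slim.losses.LOSSES_COLLECTION),
                        name='total_loss')
  print(total_loss)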