inception_model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build the Inception v3 network on ImageNet data set.

The Inception v3 architecture is described in http://arxiv.org/abs/1512.00567

Summary of available functions:
 inference: Compute inference on the model inputs to make a prediction
 loss: Compute the loss of the prediction with respect to the labels
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf

from inception.slim import slim


FLAGS = tf.app.flags.FLAGS

# If a model is trained using multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

# Batch normalization. Constant governing the exponential moving average of
# the 'global' mean and variance for all activations.
BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

# The decay to use for the moving average.
MOVING_AVERAGE_DECAY = 0.9999
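# (Used by the companion training/eval scripts to average the trainable
# variables themselves, as opposed to the batch-norm statistics above.)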


def inference(images, num_classes, for_training=False, restore_logits=True,
              scope=None):
  """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes
    for_training: If set to `True`, build the inference model for training.
      Kernels that operate differently for inference during training,
      e.g. dropout, are appropriately configured.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits. 2-D float Tensor.
    Auxiliary Logits. 2-D float Tensor of side-head. Used for training only.
  """
  # Parameters for BatchNorm.
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
  }
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
    with slim.arg_scope([slim.ops.conv2d],
                        stddev=0.1,
                        activation=tf.nn.relu,
                        batch_norm_params=batch_norm_params):
      # Force all Variables to reside on the CPU.
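      # (With multiple GPU towers this keeps one shared copy of the weights
      # on the host rather than a copy per device.)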
      with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
        logits, endpoints = slim.inception.inception_v3(
            images,
            dropout_keep_prob=0.8,
            num_classes=num_classes,
            is_training=for_training,
            restore_logits=restore_logits,
            scope=scope)

  # Add summaries for viewing model statistics on TensorBoard.
  _activation_summaries(endpoints)

  # Grab the logits associated with the side head. Employed during training.
  auxiliary_logits = endpoints['aux_logits']

  return logits, auxiliary_logits
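

# Illustrative usage (a sketch only; image_processing.distorted_inputs and
# `dataset` come from the surrounding inception codebase, and 1001 is the
# 1000 ImageNet classes plus a background class, as in inception_train.py):
#
#   images, labels = image_processing.distorted_inputs(dataset)
#   logits = inference(images, num_classes=1001, for_training=True)
#   loss(logits, labels)  # `logits` here is the (main, auxiliary) pair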


def loss(logits, labels, batch_size=None):
  """Adds all losses for the model.

  Note the final loss is not returned. Instead, the list of losses is collected
  by slim.losses. The losses are accumulated in tower_loss() and summed to
  calculate the total loss.

  Args:
    logits: List of logits from inference(). Each entry is a 2-D float Tensor.
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
      of shape [batch_size]
    batch_size: integer
  """
  if not batch_size:
    batch_size = FLAGS.batch_size

  # Reshape the labels into a dense Tensor of
  # shape [FLAGS.batch_size, num_classes].
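  # For example, with batch_size=2 and labels=[3, 1] the code below builds
  # the index pairs [[0, 3], [1, 1]] and scatters 1.0 into a zero matrix,
  # giving the one-hot rows [0, 0, 0, 1, ...] and [0, 1, 0, 0, ...].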
  sparse_labels = tf.reshape(labels, [batch_size, 1])
  indices = tf.reshape(tf.range(batch_size), [batch_size, 1])
  concated = tf.concat(1, [indices, sparse_labels])
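  # Note: tf.concat(1, ...) uses the pre-TF-1.0 signature with the axis
  # argument first; TF 1.0 and later expect tf.concat(values, axis).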
  num_classes = logits[0].get_shape()[-1].value
  dense_labels = tf.sparse_to_dense(concated,
                                    [batch_size, num_classes],
                                    1.0, 0.0)

  # Cross entropy loss for the main softmax prediction.
  slim.losses.cross_entropy_loss(logits[0],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=1.0)

  # Cross entropy loss for the auxiliary softmax head.
  slim.losses.cross_entropy_loss(logits[1],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=0.4,
                                 scope='aux_loss')
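
  # The two losses above are only registered in slim's losses collection;
  # a caller recovers the total along the lines of (sketch, following
  # inception_train.py, which also folds in regularization losses):
  #
  #   losses = tf.get_collection(slim.losses.LOSSES_COLLECTION, scope)
  #   total_loss = tf.add_n(losses, name='total_loss')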


def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on tensorboard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
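  # (tf.histogram_summary and tf.scalar_summary are the pre-TF-1.0 names of
  # what later became tf.summary.histogram and tf.summary.scalar.)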
  tf.histogram_summary(tensor_name + '/activations', x)
  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _activation_summaries(endpoints):
  """Adds an activation summary for each endpoint of the model."""
  with tf.name_scope('summaries'):
    for act in endpoints.values():
      _activation_summary(act)