# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Build the Inception v3 network on ImageNet data set.
- The Inception v3 architecture is described in http://arxiv.org/abs/1512.00567
- Summary of available functions:
- inference: Compute inference on the model inputs to make a prediction
- loss: Compute the loss of the prediction with respect to the labels
- """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf

from inception.slim import slim

FLAGS = tf.app.flags.FLAGS

# If a model is trained using multiple GPUs, prefix all Op names with
# TOWER_NAME to differentiate the operations. Note that this prefix is
# removed from the names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

# Batch normalization. Constant governing the exponential moving average of
# the 'global' mean and variance for all activations.
BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

# The decay to use for the exponential moving average of trained variables.
MOVING_AVERAGE_DECAY = 0.9999
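
# A sketch of what these decay constants control (the actual updates live in
# the inception.slim library, not in this file): an exponential moving
# average folds each new batch statistic into the running estimate as
#
#   moving_mean = decay * moving_mean + (1 - decay) * batch_mean
#   moving_variance = decay * moving_variance + (1 - decay) * batch_variance
#
# With decay = 0.9997, each minibatch contributes only 0.03% of the new
# estimate, so the 'global' statistics change slowly and smooth over
# minibatch noise.
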
def inference(images, num_classes, for_training=False, restore_logits=True,
              scope=None):
  """Build Inception v3 model architecture.

  See here for reference: http://arxiv.org/abs/1512.00567

  Args:
    images: Images returned from inputs() or distorted_inputs().
    num_classes: number of classes to predict.
    for_training: If set to `True`, build the inference model for training.
      Ops that behave differently during training than during inference,
      e.g. dropout, are appropriately configured.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with a different num_classes.
    scope: optional prefix string identifying the ImageNet tower.

  Returns:
    Logits: 2-D float Tensor.
    Auxiliary logits: 2-D float Tensor from the side head. Used for training
      only.
  """
  # Parameters for BatchNorm.
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
      # Small epsilon added to the variance to avoid dividing by zero.
      'epsilon': 0.001,
  }
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
    with slim.arg_scope([slim.ops.conv2d],
                        stddev=0.1,
                        activation=tf.nn.relu,
                        batch_norm_params=batch_norm_params):
      # Force all Variables to reside on the CPU.
      with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
        logits, endpoints = slim.inception.inception_v3(
            images,
            dropout_keep_prob=0.8,
            num_classes=num_classes,
            is_training=for_training,
            restore_logits=restore_logits,
            scope=scope)

  # Add summaries for viewing model statistics on TensorBoard.
  _activation_summaries(endpoints)

  # Grab the logits associated with the side head. Employed during training.
  auxiliary_logits = endpoints['aux_logits']

  return logits, auxiliary_logits
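
# A minimal usage sketch (assuming `images` is a 4-D float Tensor of shape
# [batch_size, height, width, 3] from inputs() or distorted_inputs(), and
# 1001 ImageNet classes including a background class):
#
#   logits, aux_logits = inference(images, num_classes=1001,
#                                  for_training=True, scope='tower_0')
#   loss([logits, aux_logits], labels)
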
def loss(logits, labels, batch_size=None):
  """Adds all losses for the model.

  Note that the final loss is not returned. Instead, the list of losses is
  collected by slim.losses. The losses are accumulated in tower_loss() and
  summed to calculate the total loss.

  Args:
    logits: List of logits from inference(). Each entry is a 2-D float Tensor.
    labels: Labels from distorted_inputs() or inputs(). 1-D tensor of shape
      [batch_size].
    batch_size: integer; defaults to FLAGS.batch_size when not provided.
  """
  if not batch_size:
    batch_size = FLAGS.batch_size

  # Reshape the labels into a dense Tensor of
  # shape [FLAGS.batch_size, num_classes].
  sparse_labels = tf.reshape(labels, [batch_size, 1])
  indices = tf.reshape(tf.range(batch_size), [batch_size, 1])
  concated = tf.concat(1, [indices, sparse_labels])
  num_classes = logits[0].get_shape()[-1].value
  dense_labels = tf.sparse_to_dense(concated,
                                    [batch_size, num_classes],
                                    1.0, 0.0)
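
  # As a concrete sketch with batch_size=2 and num_classes=5: labels [3, 1]
  # become indices [[0, 3], [1, 1]] after the concat, and sparse_to_dense
  # scatters 1.0 into those positions of a zero matrix:
  #
  #   dense_labels = [[0., 0., 0., 1., 0.],
  #                   [0., 1., 0., 0., 0.]]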

  # Cross entropy loss for the main softmax prediction.
  slim.losses.cross_entropy_loss(logits[0],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=1.0)

  # Cross entropy loss for the auxiliary softmax head.
  slim.losses.cross_entropy_loss(logits[1],
                                 dense_labels,
                                 label_smoothing=0.1,
                                 weight=0.4,
                                 scope='aux_loss')
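
# A sketch of how a caller such as tower_loss() might assemble the total loss
# (assuming inception.slim's losses module exposes its graph collection as
# slim.losses.LOSSES_COLLECTION; the regularization losses come from the
# weight_decay set in inference()):
#
#   losses = tf.get_collection(slim.losses.LOSSES_COLLECTION, scope)
#   regularization_losses = tf.get_collection(
#       tf.GraphKeys.REGULARIZATION_LOSSES)
#   total_loss = tf.add_n(losses + regularization_losses, name='total_loss')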

def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on TensorBoard.
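  # For example, 'tower_0/conv0/Relu' is summarized as 'conv0/Relu' (the
  # endpoint name here is illustrative).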
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
  tf.histogram_summary(tensor_name + '/activations', x)
  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _activation_summaries(endpoints):
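  """Creates activation summaries for every endpoint in the model.

  Args:
    endpoints: dict mapping endpoint names to activation Tensors, as returned
      by slim.inception.inception_v3().
  """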
  with tf.name_scope('summaries'):
    for act in endpoints.values():
      _activation_summary(act)