image_embedding.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Image embedding ops."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
# Short alias for the TF-Slim module; inception_v3 uses it for arg_scope,
# conv2d/batch_norm defaults and the pooling/dropout/flatten layers.
slim = tf.contrib.slim
  22. def inception_v3(images,
  23. trainable=True,
  24. is_training=True,
  25. weight_decay=0.00004,
  26. stddev=0.1,
  27. dropout_keep_prob=0.8,
  28. use_batch_norm=True,
  29. batch_norm_params=None,
  30. add_summaries=True,
  31. scope="InceptionV3"):
  32. """Builds an Inception V3 subgraph for image embeddings.
  33. Args:
  34. images: A float32 Tensor of shape [batch, height, width, channels].
  35. trainable: Whether the inception submodel should be trainable or not.
  36. is_training: Boolean indicating training mode or not.
  37. weight_decay: Coefficient for weight regularization.
  38. stddev: The standard deviation of the trunctated normal weight initializer.
  39. dropout_keep_prob: Dropout keep probability.
  40. use_batch_norm: Whether to use batch normalization.
  41. batch_norm_params: Parameters for batch normalization. See
  42. tf.contrib.layers.batch_norm for details.
  43. add_summaries: Whether to add activation summaries.
  44. scope: Optional Variable scope.
  45. Returns:
  46. end_points: A dictionary of activations from inception_v3 layers.
  47. """
  48. # Only consider the inception model to be in training mode if it's trainable.
  49. is_inception_model_training = trainable and is_training
  50. if use_batch_norm:
  51. # Default parameters for batch normalization.
  52. if not batch_norm_params:
  53. batch_norm_params = {
  54. "is_training": is_inception_model_training,
  55. "trainable": trainable,
  56. # Decay for the moving averages.
  57. "decay": 0.9997,
  58. # Epsilon to prevent 0s in variance.
  59. "epsilon": 0.001,
  60. # Collection containing the moving mean and moving variance.
  61. "variables_collections": {
  62. "beta": None,
  63. "gamma": None,
  64. "moving_mean": ["moving_vars"],
  65. "moving_variance": ["moving_vars"],
  66. }
  67. }
  68. else:
  69. batch_norm_params = None
  70. if trainable:
  71. weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
  72. else:
  73. weights_regularizer = None
  74. with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
  75. with slim.arg_scope(
  76. [slim.conv2d, slim.fully_connected],
  77. weights_regularizer=weights_regularizer,
  78. trainable=trainable):
  79. with slim.arg_scope(
  80. [slim.conv2d],
  81. weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
  82. activation_fn=tf.nn.relu,
  83. normalizer_fn=slim.batch_norm,
  84. normalizer_params=batch_norm_params):
  85. net, end_points = inception_v3_base(images, scope=scope)
  86. with tf.variable_scope("logits"):
  87. shape = net.get_shape()
  88. net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
  89. net = slim.dropout(
  90. net,
  91. keep_prob=dropout_keep_prob,
  92. is_training=is_inception_model_training,
  93. scope="dropout")
  94. net = slim.flatten(net, scope="flatten")
  95. # Add summaries.
  96. if add_summaries:
  97. for v in end_points.values():
  98. tf.contrib.layers.summaries.summarize_activation(v)
  99. return net