cifarnet.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains a variant of the CIFAR-10 model definition."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. slim = tf.contrib.slim
  21. trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
  22. def cifarnet(images, num_classes=10, is_training=False,
  23. dropout_keep_prob=0.5,
  24. prediction_fn=slim.softmax,
  25. scope='CifarNet'):
  26. """Creates a variant of the CifarNet model.
  27. Note that since the output is a set of 'logits', the values fall in the
  28. interval of (-infinity, infinity). Consequently, to convert the outputs to a
  29. probability distribution over the characters, one will need to convert them
  30. using the softmax function:
  31. logits = cifarnet.cifarnet(images, is_training=False)
  32. probabilities = tf.nn.softmax(logits)
  33. predictions = tf.argmax(logits, 1)
  34. Args:
  35. images: A batch of `Tensors` of size [batch_size, height, width, channels].
  36. num_classes: the number of classes in the dataset.
  37. is_training: specifies whether or not we're currently training the model.
  38. This variable will determine the behaviour of the dropout layer.
  39. dropout_keep_prob: the percentage of activation values that are retained.
  40. prediction_fn: a function to get predictions out of logits.
  41. scope: Optional variable_scope.
  42. Returns:
  43. logits: the pre-softmax activations, a tensor of size
  44. [batch_size, `num_classes`]
  45. end_points: a dictionary from components of the network to the corresponding
  46. activation.
  47. """
  48. end_points = {}
  49. with tf.variable_scope(scope, 'CifarNet', [images, num_classes]):
  50. net = slim.conv2d(images, 64, [5, 5], scope='conv1')
  51. end_points['conv1'] = net
  52. net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
  53. end_points['pool1'] = net
  54. net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
  55. net = slim.conv2d(net, 64, [5, 5], scope='conv2')
  56. end_points['conv2'] = net
  57. net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
  58. net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
  59. end_points['pool2'] = net
  60. net = slim.flatten(net)
  61. end_points['Flatten'] = net
  62. net = slim.fully_connected(net, 384, scope='fc3')
  63. end_points['fc3'] = net
  64. net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
  65. scope='dropout3')
  66. net = slim.fully_connected(net, 192, scope='fc4')
  67. end_points['fc4'] = net
  68. logits = slim.fully_connected(net, num_classes,
  69. biases_initializer=tf.zeros_initializer(),
  70. weights_initializer=trunc_normal(1/192.0),
  71. weights_regularizer=None,
  72. activation_fn=None,
  73. scope='logits')
  74. end_points['Logits'] = logits
  75. end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  76. return logits, end_points
  77. cifarnet.default_image_size = 32
  78. def cifarnet_arg_scope(weight_decay=0.004):
  79. """Defines the default cifarnet argument scope.
  80. Args:
  81. weight_decay: The weight decay to use for regularizing the model.
  82. Returns:
  83. An `arg_scope` to use for the inception v3 model.
  84. """
  85. with slim.arg_scope(
  86. [slim.conv2d],
  87. weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
  88. activation_fn=tf.nn.relu):
  89. with slim.arg_scope(
  90. [slim.fully_connected],
  91. biases_initializer=tf.constant_initializer(0.1),
  92. weights_initializer=trunc_normal(0.04),
  93. weights_regularizer=slim.l2_regularizer(weight_decay),
  94. activation_fn=tf.nn.relu) as sc:
  95. return sc