model_factory.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains a factory for building various models."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from tensorflow.contrib.slim import nets
  21. from slim.nets import lenet
  22. slim = tf.contrib.slim
  23. def get_model(name, num_classes, weight_decay=0.0, is_training=False):
  24. """Returns a model_fn such as `logits, end_points = model_fn(images)`.
  25. Args:
  26. name: The name of the model.
  27. num_classes: The number of classes to use for classification.
  28. weight_decay: The l2 coefficient for the model weights.
  29. is_training: `True` if the model is being used for training and `False`
  30. otherwise.
  31. Returns:
  32. model_fn: A function that applies the model to a batch of images. It has
  33. the following signature:
  34. logits, end_points = model_fn(images)
  35. Raises:
  36. ValueError: If model `name` is not recognized.
  37. """
  38. if name == 'inception_v1':
  39. default_image_size = nets.inception.inception_v1.default_image_size
  40. def func(images):
  41. with slim.arg_scope(nets.inception.inception_v1_arg_scope(
  42. weight_decay=weight_decay)):
  43. return nets.inception.inception_v1(images,
  44. num_classes,
  45. is_training=is_training)
  46. model_fn = func
  47. elif name == 'inception_v2':
  48. default_image_size = nets.inception.inception_v2.default_image_size
  49. def func(images):
  50. with slim.arg_scope(nets.inception.inception_v2_arg_scope(
  51. weight_decay=weight_decay)):
  52. return nets.inception.inception_v2(images,
  53. num_classes=num_classes,
  54. is_training=is_training)
  55. model_fn = func
  56. elif name == 'inception_v3':
  57. default_image_size = nets.inception.inception_v3.default_image_size
  58. def func(images):
  59. with slim.arg_scope(nets.inception.inception_v3_arg_scope(
  60. weight_decay=weight_decay)):
  61. return nets.inception.inception_v3(images,
  62. num_classes=num_classes,
  63. is_training=is_training)
  64. model_fn = func
  65. elif name == 'lenet':
  66. default_image_size = lenet.lenet.default_image_size
  67. def func(images):
  68. with slim.arg_scope(lenet.lenet_arg_scope(weight_decay=weight_decay)):
  69. return lenet.lenet(images,
  70. num_classes=num_classes,
  71. is_training=is_training)
  72. model_fn = func
  73. elif name == 'resnet_v1_50':
  74. default_image_size = nets.resnet_v1.resnet_v1.default_image_size
  75. def func(images):
  76. with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
  77. is_training, weight_decay=weight_decay)):
  78. net, end_points = nets.resnet_v1.resnet_v1_50(
  79. images, num_classes=num_classes)
  80. net = tf.squeeze(net, squeeze_dims=[1, 2])
  81. return net, end_points
  82. model_fn = func
  83. elif name == 'resnet_v1_101':
  84. default_image_size = nets.resnet_v1.resnet_v1.default_image_size
  85. def func(images):
  86. with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
  87. is_training, weight_decay=weight_decay)):
  88. net, end_points = nets.resnet_v1.resnet_v1_101(
  89. images, num_classes=num_classes)
  90. net = tf.squeeze(net, squeeze_dims=[1, 2])
  91. return net, end_points
  92. model_fn = func
  93. elif name == 'resnet_v1_152':
  94. default_image_size = nets.resnet_v1.resnet_v1.default_image_size
  95. def func(images):
  96. with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
  97. is_training, weight_decay=weight_decay)):
  98. net, end_points = nets.resnet_v1.resnet_v1_152(
  99. images, num_classes=num_classes)
  100. net = tf.squeeze(net, squeeze_dims=[1, 2])
  101. return net, end_points
  102. model_fn = func
  103. elif name == 'vgg_a':
  104. default_image_size = nets.vgg.vgg_a.default_image_size
  105. def func(images):
  106. with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
  107. return nets.vgg.vgg_a(images,
  108. num_classes=num_classes,
  109. is_training=is_training)
  110. model_fn = func
  111. elif name == 'vgg_16':
  112. default_image_size = nets.vgg.vgg_16.default_image_size
  113. def func(images):
  114. with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
  115. return nets.vgg.vgg_16(images,
  116. num_classes=num_classes,
  117. is_training=is_training)
  118. model_fn = func
  119. elif name == 'vgg_19':
  120. default_image_size = nets.vgg.vgg_19.default_image_size
  121. def func(images):
  122. with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
  123. return nets.vgg.vgg_19(images,
  124. num_classes=num_classes,
  125. is_training=is_training)
  126. model_fn = func
  127. else:
  128. raise ValueError('Model name [%s] was not recognized' % name)
  129. model_fn.default_image_size = default_image_size
  130. return model_fn