123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Contains a factory for building various models."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import functools
- import tensorflow as tf
- from nets import alexnet
- from nets import cifarnet
- from nets import inception
- from nets import lenet
- from nets import overfeat
- from nets import resnet_v1
- from nets import resnet_v2
- from nets import vgg
# Short alias for TF-Slim; `slim.arg_scope` is used when building networks.
slim = tf.contrib.slim

# Registry of model-building functions, keyed by the network name that callers
# pass to `get_network_fn`. Keyword-form dict() yields the same keys, values and
# insertion order as the equivalent literal.
networks_map = dict(
    alexnet_v2=alexnet.alexnet_v2,
    cifarnet=cifarnet.cifarnet,
    overfeat=overfeat.overfeat,
    vgg_a=vgg.vgg_a,
    vgg_16=vgg.vgg_16,
    vgg_19=vgg.vgg_19,
    inception_v1=inception.inception_v1,
    inception_v2=inception.inception_v2,
    inception_v3=inception.inception_v3,
    inception_v4=inception.inception_v4,
    inception_resnet_v2=inception.inception_resnet_v2,
    lenet=lenet.lenet,
    resnet_v1_50=resnet_v1.resnet_v1_50,
    resnet_v1_101=resnet_v1.resnet_v1_101,
    resnet_v1_152=resnet_v1.resnet_v1_152,
    resnet_v1_200=resnet_v1.resnet_v1_200,
    resnet_v2_50=resnet_v2.resnet_v2_50,
    resnet_v2_101=resnet_v2.resnet_v2_101,
    resnet_v2_152=resnet_v2.resnet_v2_152,
    resnet_v2_200=resnet_v2.resnet_v2_200,
)
# Maps each network name to the function that builds its default arg_scope.
# `get_network_fn` calls the selected entry with a `weight_decay` keyword.
# Several networks intentionally share one scope builder (all VGG variants,
# all ResNet v1 depths, all ResNet v2 depths).
# NOTE(review): inception_v1 and inception_v2 reuse the Inception-v3 arg scope;
# confirm this is intended rather than a copy-paste slip.
arg_scopes_map = dict(
    alexnet_v2=alexnet.alexnet_v2_arg_scope,
    cifarnet=cifarnet.cifarnet_arg_scope,
    overfeat=overfeat.overfeat_arg_scope,
    vgg_a=vgg.vgg_arg_scope,
    vgg_16=vgg.vgg_arg_scope,
    vgg_19=vgg.vgg_arg_scope,
    inception_v1=inception.inception_v3_arg_scope,
    inception_v2=inception.inception_v3_arg_scope,
    inception_v3=inception.inception_v3_arg_scope,
    inception_v4=inception.inception_v4_arg_scope,
    inception_resnet_v2=inception.inception_resnet_v2_arg_scope,
    lenet=lenet.lenet_arg_scope,
    resnet_v1_50=resnet_v1.resnet_arg_scope,
    resnet_v1_101=resnet_v1.resnet_arg_scope,
    resnet_v1_152=resnet_v1.resnet_arg_scope,
    resnet_v1_200=resnet_v1.resnet_arg_scope,
    resnet_v2_50=resnet_v2.resnet_arg_scope,
    resnet_v2_101=resnet_v2.resnet_arg_scope,
    resnet_v2_152=resnet_v2.resnet_arg_scope,
    resnet_v2_200=resnet_v2.resnet_arg_scope,
)
def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
  """Returns a network_fn such as `logits, end_points = network_fn(images)`.

  Args:
    name: The name of the network; must be a key of both `networks_map` and
      `arg_scopes_map`.
    num_classes: The number of classes to use for classification.
    weight_decay: The l2 coefficient for the model weights.
    is_training: `True` if the model is being used for training and `False`
      otherwise.

  Returns:
    network_fn: A function that applies the model to a batch of images. It has
      the following signature:
        logits, end_points = network_fn(images, **kwargs)
      Extra keyword arguments are forwarded unchanged to the underlying network
      function (`num_classes` and `is_training` are already fixed by this
      factory and must not be passed again). The returned function also
      carries the network function's attributes, e.g. `default_image_size`
      when the network defines one.

  Raises:
    ValueError: If network `name` is not recognized.
  """
  # Check both registries so a name missing from either one surfaces as the
  # documented ValueError rather than a bare KeyError.
  if name not in networks_map or name not in arg_scopes_map:
    raise ValueError('Name of network unknown %s' % name)
  arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
  func = networks_map[name]

  # functools.wraps copies __name__/__doc__ and also merges func.__dict__ into
  # the wrapper's __dict__, so attributes such as `default_image_size` are
  # propagated to network_fn without a manual copy.
  @functools.wraps(func)
  def network_fn(images, **kwargs):
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training, **kwargs)

  return network_fn
|