|
|
@@ -117,6 +117,7 @@ def resnet_v2(inputs,
|
|
|
global_pool=True,
|
|
|
output_stride=None,
|
|
|
include_root_block=True,
|
|
|
+ spatial_squeeze=True,
|
|
|
reuse=None,
|
|
|
scope=None):
|
|
|
"""Generator for v2 (preactivation) ResNet models.
|
|
|
@@ -157,6 +158,8 @@ def resnet_v2(inputs,
|
|
|
include_root_block: If True, include the initial convolution followed by
|
|
|
max-pooling, if False excludes it. If excluded, `inputs` should be the
|
|
|
results of an activation-less convolution.
|
|
|
+ spatial_squeeze: if True, logits is of shape [B, C], if false logits is
|
|
|
+ of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
|
|
|
reuse: whether or not the network and its variables should be reused. To be
|
|
|
able to reuse 'scope' must be given.
|
|
|
scope: Optional variable_scope.
|
|
|
@@ -206,12 +209,14 @@ def resnet_v2(inputs,
|
|
|
if num_classes is not None:
|
|
|
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
|
|
|
normalizer_fn=None, scope='logits')
|
|
|
- logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
|
|
|
+ if spatial_squeeze:
|
|
|
+ logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
|
|
|
# Convert end_points_collection into a dictionary of end_points.
|
|
|
end_points = slim.utils.convert_collection_to_dict(end_points_collection)
|
|
|
if num_classes is not None:
|
|
|
- end_points['predictions'] = slim.softmax(net, scope='predictions')
|
|
|
+ end_points['predictions'] = slim.softmax(logits, scope='predictions')
|
|
|
return logits, end_points
|
|
|
+resnet_v2.default_image_size = 224
|
|
|
|
|
|
|
|
|
def resnet_v2_50(inputs,
|
|
|
@@ -234,7 +239,8 @@ def resnet_v2_50(inputs,
|
|
|
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
|
|
global_pool=global_pool, output_stride=output_stride,
|
|
|
include_root_block=True, reuse=reuse, scope=scope)
|
|
|
-resnet_v2_50.default_image_size = 224
|
|
|
+resnet_v2_50.default_image_size = resnet_v2.default_image_size
|
|
|
+
|
|
|
|
|
|
def resnet_v2_101(inputs,
|
|
|
num_classes=None,
|
|
|
@@ -256,7 +262,7 @@ def resnet_v2_101(inputs,
|
|
|
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
|
|
global_pool=global_pool, output_stride=output_stride,
|
|
|
include_root_block=True, reuse=reuse, scope=scope)
|
|
|
-resnet_v2_101.default_image_size = 224
|
|
|
+resnet_v2_101.default_image_size = resnet_v2.default_image_size
|
|
|
|
|
|
|
|
|
def resnet_v2_152(inputs,
|
|
|
@@ -279,7 +285,7 @@ def resnet_v2_152(inputs,
|
|
|
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
|
|
global_pool=global_pool, output_stride=output_stride,
|
|
|
include_root_block=True, reuse=reuse, scope=scope)
|
|
|
-resnet_v2_152.default_image_size = 224
|
|
|
+resnet_v2_152.default_image_size = resnet_v2.default_image_size
|
|
|
|
|
|
|
|
|
def resnet_v2_200(inputs,
|
|
|
@@ -302,4 +308,4 @@ def resnet_v2_200(inputs,
|
|
|
return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
|
|
|
global_pool=global_pool, output_stride=output_stride,
|
|
|
include_root_block=True, reuse=reuse, scope=scope)
|
|
|
-resnet_v2_200.default_image_size = 224
|
|
|
+resnet_v2_200.default_image_size = resnet_v2.default_image_size
|