# Original code:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py
from collections import namedtuple

import torch
import torch.nn as nn


class bottleneck_IR(nn.Module):
    """Improved-residual (IR) unit: BN -> 3x3 conv -> PReLU -> strided 3x3
    conv -> BN, added to an identity (or projected) shortcut."""

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            # Channel counts already match, so the shortcut only needs to
            # match the spatial stride; a 1x1 max-pool just subsamples.
            self.shortcut_layer = nn.MaxPool2d(1, stride)
        else:
            # Project the shortcut to the new channel depth with a strided
            # 1x1 convolution.
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                nn.BatchNorm2d(depth),
            )
        self.res_layer = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            nn.PReLU(depth),
            nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            nn.BatchNorm2d(depth),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
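
# Quick shape sanity check (illustrative, not in the original file): a
# stride-2 unit that widens 64 -> 128 channels should halve the spatial
# resolution. The tensor sizes below are assumptions for the example.
def _demo_bottleneck_ir():
    unit = bottleneck_IR(in_channel=64, depth=128, stride=2)
    out = unit(torch.randn(2, 64, 56, 56))
    assert out.shape == (2, 128, 28, 28)
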
class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, num_units, stride=2):
    # One strided unit to downsample/widen, then (num_units - 1) stride-1 units.
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for _ in range(num_units - 1)
    ]
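
# Illustrative only (helper name is hypothetical): get_block expands one
# stage spec into its unit list, e.g. one stride-2 downsampling unit
# followed by three stride-1 units.
def _demo_get_block():
    stage = get_block(in_channel=64, depth=128, num_units=4)
    assert [b.stride for b in stage] == [2, 1, 1, 1]
    assert [b.in_channel for b in stage] == [64, 128, 128, 128]
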
class Backbone(nn.Module):
    """IR-50 face-recognition backbone: a 3x3 stem, four IR stages
    (3/4/14/3 units), and a BN-Dropout-Linear head producing a 512-d
    embedding."""

    def __init__(self, input_size):
        super(Backbone, self).__init__()
        assert input_size[0] in [
            112,
            224,
        ], "input_size should be [112, 112] or [224, 224]"
        # IR-50 stage plan: each get_block call is one stage whose first
        # unit downsamples by 2.
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
        unit_module = bottleneck_IR
        self.input_layer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
        )
        # The four stride-2 stages reduce the input by 16x: 112 -> 7, 224 -> 14.
        if input_size[0] == 112:
            self.output_layer = nn.Sequential(
                nn.BatchNorm2d(512),
                nn.Dropout(),
                nn.Flatten(),
                nn.Linear(512 * 7 * 7, 512),
                nn.BatchNorm1d(512),
            )
        else:
            self.output_layer = nn.Sequential(
                nn.BatchNorm2d(512),
                nn.Dropout(),
                nn.Flatten(),
                nn.Linear(512 * 14 * 14, 512),
                nn.BatchNorm1d(512),
            )
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(
                        bottleneck.in_channel,
                        bottleneck.depth,
                        bottleneck.stride,
                    )
                )
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return x
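
# Minimal usage sketch (assumptions: random weights, eval mode, a dummy
# batch of four 112x112 RGB images; this block is not from the original repo).
if __name__ == "__main__":
    model = Backbone(input_size=[112, 112])
    model.eval()  # use running BatchNorm stats, disable Dropout
    with torch.no_grad():
        embeddings = model(torch.randn(4, 3, 112, 112))
    print(embeddings.shape)  # torch.Size([4, 512])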