backbone.py

# Original code:
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/backbone/model_irse.py
from collections import namedtuple

import torch
import torch.nn as nn


class bottleneck_IR(nn.Module):
    """Improved-residual (IR) bottleneck unit from the IR-SE ResNet family."""

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        # When the channel counts match, the shortcut is a (possibly strided)
        # identity; otherwise a 1x1 convolution projects to `depth` channels.
        if in_channel == depth:
            self.shortcut_layer = nn.MaxPool2d(1, stride)
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                nn.BatchNorm2d(depth),
            )
        # Residual branch: BN -> 3x3 conv -> PReLU -> strided 3x3 conv -> BN.
        self.res_layer = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            nn.PReLU(depth),
            nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            nn.BatchNorm2d(depth),
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
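
# Shape sketch (illustrative, not in the original file): a stride-2 IR unit
# halves the spatial size while changing channels, e.g.
#   bottleneck_IR(64, 128, 2) maps (N, 64, 56, 56) -> (N, 128, 28, 28).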

class Bottleneck(namedtuple("Block", ["in_channel", "depth", "stride"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, num_units, stride=2):
    # The first unit downsamples (stride 2 by default); the remaining
    # num_units - 1 units keep the spatial resolution.
    return [Bottleneck(in_channel, depth, stride)] + [
        Bottleneck(depth, depth, 1) for i in range(num_units - 1)
    ]
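
# Example (illustrative): get_block(64, 128, num_units=4) yields
#   [Bottleneck(64, 128, 2), Bottleneck(128, 128, 1),
#    Bottleneck(128, 128, 1), Bottleneck(128, 128, 1)]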

class Backbone(nn.Module):
    def __init__(self, input_size):
        super(Backbone, self).__init__()
        assert input_size[0] in [
            112,
            224,
        ], "input_size should be [112, 112] or [224, 224]"
        # ResNet-50-style layout: 3 + 4 + 14 + 3 bottleneck units per stage.
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3),
        ]
        unit_module = bottleneck_IR
        self.input_layer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
        )
        # The four stride-2 stages shrink the input 16x, so the final feature
        # map is 7x7 for 112x112 inputs and 14x14 for 224x224 inputs.
        if input_size[0] == 112:
            self.output_layer = nn.Sequential(
                nn.BatchNorm2d(512),
                nn.Dropout(),
                nn.Flatten(),
                nn.Linear(512 * 7 * 7, 512),
                nn.BatchNorm1d(512),
            )
        else:
            self.output_layer = nn.Sequential(
                nn.BatchNorm2d(512),
                nn.Dropout(),
                nn.Flatten(),
                nn.Linear(512 * 14 * 14, 512),
                nn.BatchNorm1d(512),
            )
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(
                        bottleneck.in_channel,
                        bottleneck.depth,
                        bottleneck.stride,
                    )
                )
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        # Flattened features are projected to a 512-dimensional embedding.
        x = self.output_layer(x)
        return x
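
# Minimal smoke test (assumed usage, not part of the original file): builds
# the 112x112 variant and checks that a batch maps to 512-d embeddings.
if __name__ == "__main__":
    model = Backbone(input_size=[112, 112])
    model.eval()  # disable Dropout and use BatchNorm running statistics
    with torch.no_grad():
        dummy = torch.randn(2, 3, 112, 112)  # batch of two RGB images
        embeddings = model(dummy)
    print(embeddings.shape)  # torch.Size([2, 512])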