PyTorchFullyConvolutionalResnet18.py (2.2 KB, 61 lines)

  1. import torch
  2. import torch.nn as nn
  3. from torch.hub import load_state_dict_from_url
  4. from torchvision import models
  5. # Define the architecture by modifying resnet.
  6. # Original code is here
  7. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  8. class FullyConvolutionalResnet18(models.ResNet):
  9. def __init__(self, num_classes=1000, pretrained=False, **kwargs):
  10. # Start with standard resnet18 defined here
  11. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  12. super().__init__(
  13. block=models.resnet.BasicBlock,
  14. layers=[2, 2, 2, 2],
  15. num_classes=num_classes,
  16. **kwargs,
  17. )
  18. if pretrained:
  19. state_dict = load_state_dict_from_url(
  20. models.resnet.model_urls["resnet18"], progress=True,
  21. )
  22. self.load_state_dict(state_dict)
  23. # Replace AdaptiveAvgPool2d with standard AvgPool2d
  24. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L153-L154
  25. self.avgpool = nn.AvgPool2d((7, 7))
  26. # Add final Convolution Layer.
  27. self.last_conv = torch.nn.Conv2d(
  28. in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1,
  29. )
  30. self.last_conv.weight.data.copy_(
  31. self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1),
  32. )
  33. self.last_conv.bias.data.copy_(self.fc.bias.data)
  34. # Reimplementing forward pass.
  35. # Replacing the following code
  36. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L197-L213
  37. def _forward_impl(self, x):
  38. # Standard forward for resnet18
  39. x = self.conv1(x)
  40. x = self.bn1(x)
  41. x = self.relu(x)
  42. x = self.maxpool(x)
  43. x = self.layer1(x)
  44. x = self.layer2(x)
  45. x = self.layer3(x)
  46. x = self.layer4(x)
  47. x = self.avgpool(x)
  48. # Notice, there is no forward pass
  49. # through the original fully connected layer.
  50. # Instead, we forward pass through the last conv layer
  51. x = self.last_conv(x)
  52. return x