# resnet.py
  1. import torch
  2. from torchvision import models
  3. from torch.hub import load_state_dict_from_url
  4. # Define the architecture by modifying resnet.
  5. # Original code is here
  6. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  7. class ResNet101(models.ResNet):
  8. def __init__(self, num_classes=1000, pretrained=True, **kwargs):
  9. # Start with standard resnet101 defined here
  10. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  11. super().__init__(block=models.resnet.Bottleneck, layers=[3, 4, 23, 3], num_classes=num_classes, **kwargs)
  12. if pretrained:
  13. state_dict = load_state_dict_from_url(models.resnet.model_urls['resnet101'], progress=True)
  14. self.load_state_dict(state_dict)
  15. # Reimplementing forward pass.
  16. # Replacing the following code
  17. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L197-L213
  18. def _forward_impl(self, x):
  19. # Standard forward for resnet
  20. x = self.conv1(x)
  21. x = self.bn1(x)
  22. x = self.relu(x)
  23. x = self.maxpool(x)
  24. x = self.layer1(x)
  25. x = self.layer2(x)
  26. x = self.layer3(x)
  27. x = self.layer4(x)
  28. # Notice there is no forward pass through the original classifier.
  29. x = self.avgpool(x)
  30. x = torch.flatten(x, 1)
  31. return x