# ResNet18.py
  1. from torchvision import models
  2. from PIL import Image
  3. import cv2
  4. import torch
  5. from torchsummary import summary
  6. from torchvision import transforms
  7. transform = transforms.Compose([ #[1]
  8. # transforms.Resize(256), #[2]
  9. # transforms.CenterCrop(224), #[3]
  10. transforms.ToTensor(), #[4]
  11. transforms.Normalize( #[5]
  12. mean=[0.485, 0.456, 0.406], #[6]
  13. std=[0.229, 0.224, 0.225] #[7]
  14. )])
  15. with open('imagenet_classes.txt') as f:
  16. labels = [line.strip() for line in f.readlines()]
  17. dir(models)
  18. img = Image.open("camel.jpg")
  19. img_t = transform(img)
  20. batch_t = torch.unsqueeze(img_t, 0)
  21. # First, load the model
  22. resnet = models.resnet18(pretrained=True)
  23. summary(resnet, (3, 224,224))
  24. # Second, put the network in eval mode
  25. resnet.eval()
  26. # Third, carry out model inference
  27. preds = resnet(batch_t)
  28. pred, class_idx = torch.max(preds, dim=1)
  29. print(labels[class_idx])