channel_replication.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import torch
  2. from torchvision import models
  3. import cv2
  4. import numpy as np
  5. def classify_image(image):
  6. model = models.resnet18(pretrained=True).eval()
  7. imagenet_means = [0.485, 0.456, 0.406][::-1]
  8. imagenet_stds = [0.229, 0.224, 0.225][::-1]
  9. image = (image / 255.0 - imagenet_means) / imagenet_stds
  10. image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)
  11. image = image.unsqueeze(0)
  12. with torch.no_grad():
  13. preds = model(image)
  14. pred, class_idx = torch.max(preds, dim=1)
  15. print('Class id: {}, confidence: {}'.format(class_idx.item(), pred.item()))
  16. def classify_grayscale():
  17. image = cv2.imread("dog-basset-hound.jpg", cv2.IMREAD_GRAYSCALE)
  18. image = cv2.resize(image, (224, 224))
  19. image = np.stack((image, image, image), axis=2)
  20. classify_image(image)
  21. def classify_colorful():
  22. image = cv2.imread("dog-basset-hound.jpg", cv2.IMREAD_UNCHANGED)
  23. image = cv2.resize(image, (224, 224))
  24. classify_image(image)
  25. if __name__ == "__main__":
  26. classify_colorful()
  27. classify_grayscale()