embeddings.py

# Original code
# https://github.com/ZhaoJ9014/face.evoLVe.PyTorch/blob/master/util/extract_feature_v1.py
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

from backbone import Backbone


def get_embeddings(data_root, model_root, input_size=[112, 112], embedding_size=512):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # check data and model paths
    assert os.path.exists(data_root)
    assert os.path.exists(model_root)
    print(f"Data root: {data_root}")

    # define image preprocessing
    transform = transforms.Compose(
        [
            transforms.Resize(
                [int(128 * input_size[0] / 112), int(128 * input_size[0] / 112)],
            ),  # smaller side resized
            transforms.CenterCrop([input_size[0], input_size[1]]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ],
    )
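    # Normalize(mean=0.5, std=0.5) maps pixel values from [0, 1] to [-1, 1];
    # for input_size=[112, 112] each image is resized to 128x128 and then
    # center-cropped back to 112x112.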

    # define data loader
    dataset = datasets.ImageFolder(data_root, transform)
    loader = data.DataLoader(
        dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=0,
    )
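    # shuffle=False (with batch_size=1) keeps the loader in dataset.samples order,
    # so row idx of the embeddings matrix corresponds to images[idx] below.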
    print(f"Number of classes: {len(loader.dataset.classes)}")

    # load backbone weights from a checkpoint
    backbone = Backbone(input_size)
    backbone.load_state_dict(torch.load(model_root, map_location=torch.device("cpu")))
    backbone.to(device)
    backbone.eval()
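    # eval() puts BatchNorm/Dropout layers into inference mode; gradient tracking
    # is disabled below with torch.no_grad() since only forward passes are needed.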

    # get embedding for each face
    embeddings = np.zeros([len(loader.dataset), embedding_size])
    with torch.no_grad():
        for idx, (image, _) in enumerate(
            tqdm(loader, desc="Create embeddings matrix", total=len(loader)),
        ):
            embeddings[idx, :] = F.normalize(backbone(image.to(device))).cpu()
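    # F.normalize L2-normalizes each embedding, so two faces can be compared with
    # a plain dot product (equivalent to cosine similarity).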

    # get all original images
    images = []
    for img_path, _ in dataset.samples:
        img = cv2.imread(img_path)
        images.append(img)

    return images, embeddings
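

# A minimal usage sketch, assuming an ImageFolder-style directory of aligned face
# crops (one subfolder per identity) and a Backbone checkpoint; both paths below
# are hypothetical placeholders, not files from the original repository.
if __name__ == "__main__":
    images, embeddings = get_embeddings(
        data_root="data/faces",  # hypothetical dataset directory
        model_root="checkpoints/backbone_ir50.pth",  # hypothetical checkpoint path
    )
    print(f"Loaded {len(images)} images; embeddings shape: {embeddings.shape}")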