dataset.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import csv
  2. import numpy as np
  3. from PIL import Image
  4. from torch.utils.data import Dataset
  5. mean = [0.485, 0.456, 0.406]
  6. std = [0.229, 0.224, 0.225]
  7. class AttributesDataset():
  8. def __init__(self, annotation_path):
  9. color_labels = []
  10. gender_labels = []
  11. article_labels = []
  12. with open(annotation_path) as f:
  13. reader = csv.DictReader(f)
  14. for row in reader:
  15. color_labels.append(row['baseColour'])
  16. gender_labels.append(row['gender'])
  17. article_labels.append(row['articleType'])
  18. self.color_labels = np.unique(color_labels)
  19. self.gender_labels = np.unique(gender_labels)
  20. self.article_labels = np.unique(article_labels)
  21. self.num_colors = len(self.color_labels)
  22. self.num_genders = len(self.gender_labels)
  23. self.num_articles = len(self.article_labels)
  24. self.color_id_to_name = dict(zip(range(len(self.color_labels)), self.color_labels))
  25. self.color_name_to_id = dict(zip(self.color_labels, range(len(self.color_labels))))
  26. self.gender_id_to_name = dict(zip(range(len(self.gender_labels)), self.gender_labels))
  27. self.gender_name_to_id = dict(zip(self.gender_labels, range(len(self.gender_labels))))
  28. self.article_id_to_name = dict(zip(range(len(self.article_labels)), self.article_labels))
  29. self.article_name_to_id = dict(zip(self.article_labels, range(len(self.article_labels))))
  30. class FashionDataset(Dataset):
  31. def __init__(self, annotation_path, attributes, transform=None):
  32. super().__init__()
  33. self.transform = transform
  34. self.attr = attributes
  35. # initialize the arrays to store the ground truth labels and paths to the images
  36. self.data = []
  37. self.color_labels = []
  38. self.gender_labels = []
  39. self.article_labels = []
  40. # read the annotations from the CSV file
  41. with open(annotation_path) as f:
  42. reader = csv.DictReader(f)
  43. for row in reader:
  44. self.data.append(row['image_path'])
  45. self.color_labels.append(self.attr.color_name_to_id[row['baseColour']])
  46. self.gender_labels.append(self.attr.gender_name_to_id[row['gender']])
  47. self.article_labels.append(self.attr.article_name_to_id[row['articleType']])
  48. def __len__(self):
  49. return len(self.data)
  50. def __getitem__(self, idx):
  51. # take the data sample by its index
  52. img_path = self.data[idx]
  53. # read image
  54. img = Image.open(img_path)
  55. # apply the image augmentations if needed
  56. if self.transform:
  57. img = self.transform(img)
  58. # return the image and all the associated labels
  59. dict_data = {
  60. 'img': img,
  61. 'labels': {
  62. 'color_labels': self.color_labels[idx],
  63. 'gender_labels': self.gender_labels[idx],
  64. 'article_labels': self.article_labels[idx]
  65. }
  66. }
  67. return dict_data