12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import csv
- import numpy as np
- from PIL import Image
- from torch.utils.data import Dataset
# Per-channel normalisation statistics, three values each.
# NOTE(review): these figures match the mean/std that torchvision documents
# for its ImageNet-pretrained models, presumably in R, G, B order — confirm
# that is the intent. Nothing in the visible part of this file reads them, so
# they are presumably imported by the code that builds the transform pipeline.
mean = [
    0.485,
    0.456,
    0.406,
]
std = [
    0.229,
    0.224,
    0.225,
]
class AttributesDataset:
    """Vocabulary of the attribute labels found in an annotation CSV.

    Reads the ``baseColour``, ``gender`` and ``articleType`` columns and, for
    each attribute, exposes:

    * ``<attr>_labels`` - the sorted unique label names (``np.unique`` output),
    * ``num_<attr>s`` - how many distinct labels there are,
    * ``<attr>_id_to_name`` / ``<attr>_name_to_id`` - lookup dictionaries in
      which an id is the label's position in the sorted array.
    """

    def __init__(self, annotation_path):
        """Collect the label vocabularies from ``annotation_path``.

        Args:
            annotation_path: Path to a CSV file whose header row contains the
                ``baseColour``, ``gender`` and ``articleType`` columns.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If a required column is absent from the CSV header
                (raised when the first data row is processed).
        """
        colors, genders, articles = [], [], []

        # newline='' is what the csv module requires of file objects: without
        # it, universal-newline translation rewrites "\r\n" inside quoted
        # fields before the csv parser ever sees them.
        with open(annotation_path, newline='') as f:
            for row in csv.DictReader(f):
                colors.append(row['baseColour'])
                genders.append(row['gender'])
                articles.append(row['articleType'])

        # np.unique both de-duplicates and sorts, so ids are stable for a
        # given set of label names regardless of row order in the file.
        self.color_labels = np.unique(colors)
        self.gender_labels = np.unique(genders)
        self.article_labels = np.unique(articles)

        self.num_colors = len(self.color_labels)
        self.num_genders = len(self.gender_labels)
        self.num_articles = len(self.article_labels)

        self.color_id_to_name, self.color_name_to_id = self._build_lookups(self.color_labels)
        self.gender_id_to_name, self.gender_name_to_id = self._build_lookups(self.gender_labels)
        self.article_id_to_name, self.article_name_to_id = self._build_lookups(self.article_labels)

    @staticmethod
    def _build_lookups(labels):
        """Return ``(id_to_name, name_to_id)`` for a sequence of unique labels."""
        id_to_name = dict(enumerate(labels))
        # Labels are unique, so inverting the mapping loses nothing.
        name_to_id = {name: idx for idx, name in id_to_name.items()}
        return id_to_name, name_to_id
class FashionDataset(Dataset):
    """Dataset pairing each annotated image with its three attribute label ids.

    The annotation CSV is read once at construction time; images themselves
    are opened only when an item is requested.
    """

    def __init__(self, annotation_path, attributes, transform=None):
        """Load image paths and encoded labels from ``annotation_path``.

        Args:
            annotation_path: Path to a CSV file whose header row contains the
                ``image_path``, ``baseColour``, ``gender`` and ``articleType``
                columns.
            attributes: Object providing the ``color_name_to_id``,
                ``gender_name_to_id`` and ``article_name_to_id`` mappings used
                to turn label names into ids.
            transform: Optional callable applied to every opened image.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If a required column is missing from the header, or a
                label in the file is not present in the matching mapping.
        """
        super().__init__()

        self.transform = transform
        self.attr = attributes

        # Parallel lists: index i of each list describes the same sample.
        self.data = []
        self.color_labels = []
        self.gender_labels = []
        self.article_labels = []

        # newline='' is what the csv module requires of file objects: without
        # it, "\r\n" inside quoted fields is rewritten before parsing.
        with open(annotation_path, newline='') as f:
            for row in csv.DictReader(f):
                self.data.append(row['image_path'])
                self.color_labels.append(self.attr.color_name_to_id[row['baseColour']])
                self.gender_labels.append(self.attr.gender_name_to_id[row['gender']])
                self.article_labels.append(self.attr.article_name_to_id[row['articleType']])

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image and label ids of the sample at ``idx``.

        Returns:
            A dict with key ``'img'`` (the opened image, after ``transform``
            when one was given) and key ``'labels'``, itself a dict with the
            ``'color_labels'``, ``'gender_labels'`` and ``'article_labels'``
            ids of the sample.
        """
        # Image.open is lazy and keeps the file open until pixel data is read.
        # Loading inside the context manager releases the file handle before
        # returning, even when no transform forces a read.
        # NOTE(review): the image mode is passed through unchanged; if the
        # pipeline expects 3-channel input, non-RGB files may need
        # ``.convert('RGB')`` — confirm against the transforms in use.
        with Image.open(self.data[idx]) as img:
            img.load()

        if self.transform:
            img = self.transform(img)

        return {
            'img': img,
            'labels': {
                'color_labels': self.color_labels[idx],
                'gender_labels': self.gender_labels[idx],
                'article_labels': self.article_labels[idx],
            },
        }
|