split_data.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import argparse
  2. import csv
  3. import os
  4. import numpy as np
  5. from PIL import Image
  6. from tqdm import tqdm
  7. def save_csv(data, path, fieldnames=['image_path', 'gender', 'articleType', 'baseColour']):
  8. with open(path, 'w', newline='') as csv_file:
  9. writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
  10. writer.writeheader()
  11. for row in data:
  12. writer.writerow(dict(zip(fieldnames, row)))
  13. if __name__ == '__main__':
  14. parser = argparse.ArgumentParser(description='Split data for the dataset')
  15. parser.add_argument('--input', type=str, required=True, help="Path to the dataset")
  16. parser.add_argument('--output', type=str, required=True, help="Path to the working folder")
  17. args = parser.parse_args()
  18. input_folder = args.input
  19. output_folder = args.output
  20. annotation = os.path.join(input_folder, 'styles.csv')
  21. # open annotation file
  22. all_data = []
  23. with open(annotation) as csv_file:
  24. # parse it as CSV
  25. reader = csv.DictReader(csv_file)
  26. # tqdm shows pretty progress bar
  27. # each row in the CSV file corresponds to the image
  28. for row in tqdm(reader, total=reader.line_num):
  29. # we need image ID to build the path to the image file
  30. img_id = row['id']
  31. # we're going to use only 3 attributes
  32. gender = row['gender']
  33. articleType = row['articleType']
  34. baseColour = row['baseColour']
  35. img_name = os.path.join(input_folder, 'images', str(img_id) + '.jpg')
  36. # check if file is in place
  37. if os.path.exists(img_name):
  38. # check if the image has 80*60 pixels with 3 channels
  39. img = Image.open(img_name)
  40. if img.size == (60, 80) and img.mode == "RGB":
  41. all_data.append([img_name, gender, articleType, baseColour])
  42. # set the seed of the random numbers generator, so we can reproduce the results later
  43. np.random.seed(42)
  44. # construct a Numpy array from the list
  45. all_data = np.asarray(all_data)
  46. # Take 40000 samples in random order
  47. inds = np.random.choice(40000, 40000, replace=False)
  48. # split the data into train/val and save them as csv files
  49. save_csv(all_data[inds][:32000], os.path.join(output_folder, 'train.csv'))
  50. save_csv(all_data[inds][32000:40000], os.path.join(output_folder, 'val.csv'))