train.py

import argparse
import os
from datetime import datetime

import torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from test import calculate_metrics, validate, visualize_grid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
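

# helper utilities used by the training pipeline below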
def get_cur_time():
    return datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M')


def checkpoint_save(model, name, epoch):
    f = os.path.join(name, 'checkpoint-{:06d}.pth'.format(epoch))
    torch.save(model.state_dict(), f)
    print('Saved checkpoint:', f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training pipeline')
    parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',
                        help="Path to the file with attributes")
    parser.add_argument('--device', type=str, default='cuda', help="Device: 'cuda' or 'cpu'")
    args = parser.parse_args()
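
    # training hyperparameters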
    start_epoch = 1
    N_epochs = 50
    batch_size = 16
    num_workers = 8  # number of processes to handle dataset loading

    device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")

    # `attributes` holds the labels for each category in the dataset
    # and the mapping between string names and IDs
    attributes = AttributesDataset(args.attributes_file)

    # specify image transforms for augmentation during training
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),
        # torchvision >= 0.9 renamed 'resample'/'fillcolor' to 'interpolation'/'fill';
        # on older versions use resample=False, fillcolor=(255, 255, 255) instead
        transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.8, 1.2),
                                shear=None, fill=(255, 255, 255)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # during validation we use only tensor and normalization transforms
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = FashionDataset('./train.csv', attributes, train_transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    val_dataset = FashionDataset('./val.csv', attributes, val_transform)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
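
    # the model has one classification head per attribute (color, gender, article)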
    model = MultiOutputModel(n_color_classes=attributes.num_colors,
                             n_gender_classes=attributes.num_genders,
                             n_article_classes=attributes.num_articles) \
        .to(device)
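
    # Adam with its default settings (lr=1e-3) optimizes all heads jointly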
    optimizer = torch.optim.Adam(model.parameters())

    logdir = os.path.join('./logs/', get_cur_time())
    savedir = os.path.join('./checkpoints/', get_cur_time())
    os.makedirs(logdir, exist_ok=True)
    os.makedirs(savedir, exist_ok=True)
    logger = SummaryWriter(logdir)

    n_train_samples = len(train_dataloader)  # number of batches per epoch, used to average the running metrics

    # Uncomment the rows below to see example images with ground-truth labels from the val dataset, and all the labels:
    # visualize_grid(model, val_dataloader, attributes, device, show_cn_matrices=False, show_images=True,
    #                checkpoint=None, show_gt=True)
    # print("\nAll gender labels:\n", attributes.gender_labels)
    # print("\nAll color labels:\n", attributes.color_labels)
    # print("\nAll article labels:\n", attributes.article_labels)

    print("Starting training ...")

    for epoch in range(start_epoch, N_epochs + 1):
        total_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for batch in train_dataloader:
            optimizer.zero_grad()
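
            # each batch is a dict: 'img' holds the image tensor and 'labels'
            # a dict of per-attribute target tensors (see dataset.py)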
            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))
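
            # get_loss combines the per-head losses into the single scalar that is
            # backpropagated; losses_train presumably holds the per-head components (unused here)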
            loss_train, losses_train = model.get_loss(output, target_labels)
            total_loss += loss_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            loss_train.backward()
            optimizer.step()
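
        # end of epoch: report loss and per-head accuracies averaged over all batches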
        print("epoch {:4d}, loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}".format(
            epoch,
            total_loss / n_train_samples,
            accuracy_color / n_train_samples,
            accuracy_gender / n_train_samples,
            accuracy_article / n_train_samples))

        logger.add_scalar('train_loss', total_loss / n_train_samples, epoch)

        # run validation every 5 epochs and save a checkpoint every 25
        if epoch % 5 == 0:
            validate(model, val_dataloader, logger, epoch, device)

        if epoch % 25 == 0:
            checkpoint_save(model, savedir, epoch)
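
# example invocation, using the default paths defined in the argument parser above:
#   python train.py --attributes_file ./fashion-product-images/styles.csv --device cuda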