test.py

import argparse
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from dataset import FashionDataset, AttributesDataset, mean, std
from model import MultiOutputModel
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, balanced_accuracy_score
from torch.utils.data import DataLoader
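# `dataset` and `model` are local modules of this project (dataset.py, model.py);
# `mean` and `std` are the per-channel normalization statistics, presumably the
# same ones used during training.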


def checkpoint_load(model, name):
    print('Restoring checkpoint: {}'.format(name))
    model.load_state_dict(torch.load(name, map_location='cpu'))
    epoch = int(os.path.splitext(os.path.basename(name))[0].split('-')[1])
    return epoch
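# Note: the epoch parsing above assumes checkpoint filenames of the form
# '<prefix>-<epoch>.<ext>', e.g. 'checkpoint-000050.pth' -> epoch 50.
# The exact naming convention is an assumption inferred from the parsing
# logic, not verified against the training script.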


def validate(model, dataloader, logger, iteration, device, checkpoint=None):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)

    model.eval()
    with torch.no_grad():
        avg_loss = 0
        accuracy_color = 0
        accuracy_gender = 0
        accuracy_article = 0

        for batch in dataloader:
            img = batch['img']
            target_labels = batch['labels']
            target_labels = {t: target_labels[t].to(device) for t in target_labels}
            output = model(img.to(device))

            val_train, val_train_losses = model.get_loss(output, target_labels)
            avg_loss += val_train.item()
            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, target_labels)

            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

    # average the per-batch metrics over the number of batches
    n_samples = len(dataloader)
    avg_loss /= n_samples
    accuracy_color /= n_samples
    accuracy_gender /= n_samples
    accuracy_article /= n_samples

    print('-' * 72)
    print("Validation loss: {:.4f}, color: {:.4f}, gender: {:.4f}, article: {:.4f}\n".format(
        avg_loss, accuracy_color, accuracy_gender, accuracy_article))

    logger.add_scalar('val_loss', avg_loss, iteration)
    logger.add_scalar('val_accuracy_color', accuracy_color, iteration)
    logger.add_scalar('val_accuracy_gender', accuracy_gender, iteration)
    logger.add_scalar('val_accuracy_article', accuracy_article, iteration)

    model.train()
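# `logger` is expected to expose TensorBoard's add_scalar(tag, value, step) API,
# e.g. torch.utils.tensorboard.SummaryWriter. A minimal call-site sketch from a
# training loop (names here are hypothetical):
#
#   from torch.utils.tensorboard import SummaryWriter
#   logger = SummaryWriter('./logs')
#   validate(model, val_dataloader, logger, iteration, device)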


def visualize_grid(model, dataloader, attributes, device, show_cn_matrices=True, show_images=True, checkpoint=None,
                   show_gt=False):
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)
    model.eval()

    imgs = []
    labels = []
    gt_labels = []

    gt_color_all = []
    gt_gender_all = []
    gt_article_all = []
    predicted_color_all = []
    predicted_gender_all = []
    predicted_article_all = []

    accuracy_color = 0
    accuracy_gender = 0
    accuracy_article = 0

    with torch.no_grad():
        for batch in dataloader:
            img = batch['img']
            gt_colors = batch['labels']['color_labels']
            gt_genders = batch['labels']['gender_labels']
            gt_articles = batch['labels']['article_labels']
            output = model(img.to(device))

            batch_accuracy_color, batch_accuracy_gender, batch_accuracy_article = \
                calculate_metrics(output, batch['labels'])
            accuracy_color += batch_accuracy_color
            accuracy_gender += batch_accuracy_gender
            accuracy_article += batch_accuracy_article

            # get the most confident prediction for each image
            _, predicted_colors = output['color'].cpu().max(1)
            _, predicted_genders = output['gender'].cpu().max(1)
            _, predicted_articles = output['article'].cpu().max(1)

            for i in range(img.shape[0]):
                # undo the normalization so the image displays with natural colors
                image = np.clip(img[i].permute(1, 2, 0).numpy() * std + mean, 0, 1)

                predicted_color = attributes.color_id_to_name[predicted_colors[i].item()]
                predicted_gender = attributes.gender_id_to_name[predicted_genders[i].item()]
                predicted_article = attributes.article_id_to_name[predicted_articles[i].item()]

                gt_color = attributes.color_id_to_name[gt_colors[i].item()]
                gt_gender = attributes.gender_id_to_name[gt_genders[i].item()]
                gt_article = attributes.article_id_to_name[gt_articles[i].item()]

                gt_color_all.append(gt_color)
                gt_gender_all.append(gt_gender)
                gt_article_all.append(gt_article)

                predicted_color_all.append(predicted_color)
                predicted_gender_all.append(predicted_gender)
                predicted_article_all.append(predicted_article)

                imgs.append(image)
                labels.append("{}\n{}\n{}".format(predicted_gender, predicted_article, predicted_color))
                gt_labels.append("{}\n{}\n{}".format(gt_gender, gt_article, gt_color))

    if not show_gt:
        n_samples = len(dataloader)
        print("\nAccuracy:\ncolor: {:.4f}, gender: {:.4f}, article: {:.4f}".format(
            accuracy_color / n_samples,
            accuracy_gender / n_samples,
            accuracy_article / n_samples))
    # draw confusion matrices
    if show_cn_matrices:
        # color
        cn_matrix = confusion_matrix(
            y_true=gt_color_all,
            y_pred=predicted_color_all,
            labels=attributes.color_labels,
            normalize='true')
        ConfusionMatrixDisplay(cn_matrix, display_labels=attributes.color_labels).plot(
            include_values=False, xticks_rotation='vertical')
        plt.title("Colors")
        plt.tight_layout()
        plt.show()

        # gender
        cn_matrix = confusion_matrix(
            y_true=gt_gender_all,
            y_pred=predicted_gender_all,
            labels=attributes.gender_labels,
            normalize='true')
        ConfusionMatrixDisplay(cn_matrix, display_labels=attributes.gender_labels).plot(
            xticks_rotation='horizontal')
        plt.title("Genders")
        plt.tight_layout()
        plt.show()

        # article: this matrix is large, so render it at high DPI with a tiny font
        # (comment out this block if it is too big to display on your machine)
        cn_matrix = confusion_matrix(
            y_true=gt_article_all,
            y_pred=predicted_article_all,
            labels=attributes.article_labels,
            normalize='true')
        plt.rcParams.update({'font.size': 1.8})
        plt.rcParams.update({'figure.dpi': 300})
        ConfusionMatrixDisplay(cn_matrix, display_labels=attributes.article_labels).plot(
            include_values=False, xticks_rotation='vertical')
        plt.rcParams.update({'figure.dpi': 100})
        plt.rcParams.update({'font.size': 5})
        plt.title("Article types")
        plt.show()
    if show_images:
        labels = gt_labels if show_gt else labels
        title = "Ground truth labels" if show_gt else "Predicted labels"
        n_cols = 5
        n_rows = 3
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10))
        axs = axs.flatten()
        # zip() stops at the shortest sequence, so only the first
        # n_rows * n_cols images are shown
        for img, ax, label in zip(imgs, axs, labels):
            ax.set_xlabel(label, rotation=0)
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            ax.imshow(img)
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

    model.train()


def calculate_metrics(output, target):
    _, predicted_color = output['color'].cpu().max(1)
    gt_color = target['color_labels'].cpu()

    _, predicted_gender = output['gender'].cpu().max(1)
    gt_gender = target['gender_labels'].cpu()

    _, predicted_article = output['article'].cpu().max(1)
    gt_article = target['article_labels'].cpu()

    with warnings.catch_warnings():  # sklearn may warn when a class is missing from y_true (a zero row in the confusion matrix)
        warnings.simplefilter("ignore")
        accuracy_color = balanced_accuracy_score(y_true=gt_color.numpy(), y_pred=predicted_color.numpy())
        accuracy_gender = balanced_accuracy_score(y_true=gt_gender.numpy(), y_pred=predicted_gender.numpy())
        accuracy_article = balanced_accuracy_score(y_true=gt_article.numpy(), y_pred=predicted_article.numpy())

    return accuracy_color, accuracy_gender, accuracy_article
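# Note: balanced accuracy is the unweighted mean of per-class recall, so rare
# classes count as much as frequent ones. A quick sanity check with hypothetical
# values (not from this dataset):
#
#   >>> balanced_accuracy_score(y_true=[0, 0, 0, 1], y_pred=[0, 0, 0, 0])
#   0.5  # recall is 1.0 for class 0 and 0.0 for class 1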


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Inference pipeline')
    parser.add_argument('--checkpoint', type=str, required=True, help="Path to the checkpoint")
    parser.add_argument('--attributes_file', type=str, default='./fashion-product-images/styles.csv',
                        help="Path to the file with attributes")
    parser.add_argument('--device', type=str, default='cuda',
                        help="Device: 'cuda' or 'cpu'")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() and args.device == 'cuda' else "cpu")

    # the attributes object holds the labels for each category in the dataset
    # and the mappings between string names and IDs
    attributes = AttributesDataset(args.attributes_file)

    # during validation we use only tensor and normalization transforms
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_dataset = FashionDataset('./val.csv', attributes, val_transform)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

    model = MultiOutputModel(n_color_classes=attributes.num_colors, n_gender_classes=attributes.num_genders,
                             n_article_classes=attributes.num_articles).to(device)

    # visualize the trained model's predictions on the test split
    visualize_grid(model, test_dataloader, attributes, device, checkpoint=args.checkpoint)
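# Example invocation (the checkpoint path is a placeholder for your local setup):
#
#   python test.py --checkpoint ./checkpoints/checkpoint-000050.pth \
#       --attributes_file ./fashion-product-images/styles.csv --device cuda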