# train_resnet.py

import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from apex import amp

import config as cfg
from dataset import DataManager
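
# NOTE: apex.amp, NVIDIA's mixed-precision training library, has since been
# deprecated in favor of the native torch.cuda.amp API; it is kept here
# because the checkpointing and loss-scaling code below depend on it.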


class SoftTargetLoss(nn.Module):
    def __init__(self, temperature=1):
        """
        Soft target loss as introduced by Hinton et al.
        in https://arxiv.org/abs/1503.02531

        temperature (float or int): annealing temperature hyperparameter;
            temperature=1 corresponds to the usual softmax
        """
        super(SoftTargetLoss, self).__init__()
        self.register_buffer('temperature', torch.tensor(temperature))

    def forward(self, student_logits, teacher_logits):
        # log_softmax is numerically stabler than log(softmax(...)), and
        # dim=1 makes the class dimension explicit
        student_log_probabilities = nn.functional.log_softmax(student_logits / self.temperature, dim=1)
        teacher_probabilities = nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        loss = -torch.mul(teacher_probabilities, student_log_probabilities)
        return torch.mean(loss)
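
# Minimal usage sketch for SoftTargetLoss (illustrative only; the batch size
# and class count below are arbitrary):
#   criterion = SoftTargetLoss(temperature=4)
#   student_logits = torch.randn(8, 1000)
#   teacher_logits = torch.randn(8, 1000)
#   loss = criterion(student_logits, teacher_logits)  # scalar tensor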


class Trainer(object):
    def __init__(self, net, manager, savepath):
        """
        net (nn.Module): neural network to be trained
        manager (DataManager): data manager from dataset.py
        savepath (str): a format-ready string like 'model_{}.pth'
            on which .format can be called while saving the model
            at every epoch
        """
        self.net = net
        self.manager = manager
        self.savepath = savepath  # must contain curly brackets, e.g. 'model_{}.pth'
        self.criterion = SoftTargetLoss(cfg.TEMPERATURE)
        self.optimizer = optim.Adam(self.net.parameters(), lr=cfg.LR)
        self.writer = SummaryWriter()

    def save(self, path):
        checkpoint = {'model': self.net.state_dict(),
                      'optimizer': self.optimizer.state_dict(),
                      'amp': amp.state_dict()}
        torch.save(checkpoint, path)
        print(f'Saved model to {path}')
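
    # To resume from such a checkpoint (a sketch; assumes amp.initialize has
    # already been called with the same opt_level before loading):
    #   checkpoint = torch.load(path)
    #   net.load_state_dict(checkpoint['model'])
    #   optimizer.load_state_dict(checkpoint['optimizer'])
    #   amp.load_state_dict(checkpoint['amp'])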

    def train(self, epochs=None, evaluate_interval=None):
        steps = 0
        epochs = epochs if epochs else cfg.EPOCHS
        evaluate_interval = evaluate_interval if evaluate_interval else cfg.EVAL_INTERVAL
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        if device.type != 'cuda':
            print(f'GPU not found. Training will be done on device of type {device.type}')
        self.net.to(device)
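        # opt_level='O2' is apex's "almost FP16" mixed precision: the model is
        # cast to FP16 while the optimizer keeps FP32 master weights (see the
        # apex documentation for the full list of opt levels)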
        self.net, self.optimizer = amp.initialize(self.net, self.optimizer,
                                                  opt_level='O2')
        self.net.train()
        train_iterator, valid_iterator, *_ = self.manager.dataloaders
        get_top5_accuracy = lambda p, y: (torch.topk(p, 5, dim=1).indices == torch.argmax(y, 1)[:, None]).sum(dim=1).to(torch.float).mean().item()
        mean = lambda v: sum(v) / len(v)
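        # get_top5_accuracy counts a sample as correct when the teacher's
        # argmax class appears among the student's five highest-scoring
        # classes, then averages over the batch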
        for epoch in range(epochs):
            start_time = time.time()
            for idx, (x, y) in enumerate(train_iterator):
                self.optimizer.zero_grad()
                x = x.to(device)
                y = y.to(device)
                preds = self.net(x)
                loss = self.criterion(preds, y)
                # amp scales the loss to avoid FP16 underflow before backprop
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                self.optimizer.step()
                top5_accuracy = get_top5_accuracy(preds, y)
                # This isn't *really* the top-5 accuracy, because it is evaluated against the
                # outputs of the teacher model as opposed to ground-truth labels. Since the value
                # of the loss is not easy to grasp intuitively, this proxy serves as an easily
                # computable metric to monitor the progress of the student network, especially if
                # the training data is also ImageNet.
                self.writer.add_scalar('Loss', loss, steps)
                self.writer.add_scalar('Top-5 training accuracy', top5_accuracy, steps)
                steps += 1
                if steps % evaluate_interval == 0:
                    valid_loss = []
                    valid_accuracy = []
                    self.net.eval()  # put network in evaluation mode
                    with torch.no_grad():
                        for xv, yv in valid_iterator:
                            xv = xv.to(device)
                            yv = yv.to(device)
                            preds = self.net(xv)
                            vtop5a = get_top5_accuracy(preds, yv)
                            vloss = self.criterion(preds, yv)
                            valid_loss.append(vloss.item())
                            valid_accuracy.append(vtop5a)
                    self.writer.add_scalar('Validation Loss', mean(valid_loss), steps)
                    self.writer.add_scalar('Top-5 validation accuracy', mean(valid_accuracy), steps)
                    self.writer.flush()
                    self.net.train()  # return to training mode
            self.writer.flush()  # make sure the writer has written all stats up to this point
            self.save(self.savepath.format(epoch))
            end_time = time.time()
            print('Time taken for last epoch = {:.3f} seconds'.format(end_time - start_time))


def main():
    manager = DataManager(cfg.IMGPATH_FILE, cfg.SOFT_TARGET_PATH, cfg.SIZE)
    net = models.resnet18(pretrained=False)
    trainer = Trainer(net, manager, cfg.SAVE_PATH)
    trainer.train()


if __name__ == "__main__":
    main()
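
# config.py is assumed to provide the following names, inferred from the
# usages above: TEMPERATURE, LR, EPOCHS, EVAL_INTERVAL, IMGPATH_FILE,
# SOFT_TARGET_PATH, SIZE, and SAVE_PATH (a format string such as 'model_{}.pth').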