pytorch_profiler_demo.py

from profiler_demo_utils import *
# Importing * is not good practice, but it simplifies
# this demo. Please do not imitate this :-)


class VisionTrainer(object):
    def __init__(self, net, dm):
        self.net = net
        self.dm = dm
        self.writer = SummaryWriter()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.AdamW(self.net.parameters(), lr=1e-6)
        self.savepath = None
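        # Note: SummaryWriter() defaults to a ./runs/<timestamp>_<hostname> log directory,
        # the same top-level directory the profiler traces are written to in main()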

    def train(self, epochs, save, profiler=None):
        eval_interval = 200  # evaluate every 200 steps
        self.savepath = save
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        train_loader, valid_loader = self.dm.train_loader, self.dm.valid_loader  # ignore test loader if any
        self.net.to(device).train()
        if has_apex:
            self.net, self.optimizer = amp.initialize(self.net, self.optimizer,
                                                      opt_level='O2', enabled=True)
        step = 0
        get_accuracy = lambda p, y: (torch.argmax(p, dim=1) == y).to(torch.float).mean().item()
        for epoch in range(epochs):
            estart = time.time()
            for x, y in train_loader:
                with record_function("training_events"):  # record these as training_events
                    self.optimizer.zero_grad()
                    x = x.to(device)
                    y = y.to(device)
                    pred = self.net(x)
                    loss = self.criterion(pred, y)
                    # print(loss.item())
                    self.writer.add_scalar('Training Loss', loss.item(), step)
                    if has_apex:
                        with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # plain backward pass when apex is unavailable
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.01)
                    self.optimizer.step()
                    acc = get_accuracy(pred, y)
                    step += 1
                    self.writer.add_scalar('Training Accuracy', acc, step)
                if step % eval_interval == 0:
                    with record_function("evaluation_events"):  # record these as evaluation_events
                        self.net.eval()
                        valoss = []
                        vaacc = []
                        with torch.no_grad():
                            for imgs, ys in valid_loader:
                                imgs = imgs.to(device)
                                ys = ys.to(device)
                                preds = self.net(imgs)
                                vacc = get_accuracy(preds, ys)
                                vloss = self.criterion(preds, ys)
                                # pdb.set_trace()
                                valoss.append(vloss.item())
                                vaacc.append(vacc)
                        self.writer.add_scalar('Validation Loss', np.mean(valoss), step)
                        self.writer.add_scalar('Validation Accuracy', np.mean(vaacc), step)
                        self.net.train()
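                # profiler.step() marks the end of one training iteration so that the
                # wait/warmup/active schedule configured in main() can advance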
                if profiler:
                    profiler.step()
            self.save(epoch)
            eend = time.time()
            print('Time taken for last epoch = {:.3f}'.format(eend - estart))

    def save(self, epoch):
        if self.savepath:
            path = self.savepath.format(epoch)
            torch.save(self.net.state_dict(), path)
            print(f'Saved model to {path}')


def main():
    dm = CIFAR10_Manager('./cf10')
    # Just change mname to one of the following:
    # resnet18, resnet50, mobilenetv3, densenet, squeezenet, inception
    mname = 'resnet50'
    net = VisionClassifier(nclasses=10, mname=mname)
    trainer = VisionTrainer(net, dm)
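    # schedule(wait=1, warmup=1, active=2): skip one step, warm up for one step
    # (profiled but discarded), then record two steps before the cycle repeats.
    # Each completed trace is handed to tensorboard_trace_handler, which writes
    # it under ./runs alongside the SummaryWriter logs.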
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True,
                 schedule=schedule(wait=1,
                                   warmup=1,
                                   active=2),
                 on_trace_ready=torch.profiler.tensorboard_trace_handler('./runs'),
                 profile_memory=True,
                 use_cuda=True) as prof:
        trainer.train(epochs=1, save='models/cf10_{}.pth', profiler=prof)


if __name__ == '__main__':
    main()
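
# To view the results, point TensorBoard at the log directory; the profiler view
# typically requires the torch-tb-profiler plugin:
#   pip install torch_tb_profiler
#   tensorboard --logdir=./runs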