profiler_demo_utils.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import torch.nn as nn
  2. import torch.optim as optim
  3. from torchvision import models, datasets
  4. import torch
  5. import numpy as np
  6. import time
  7. from torch.utils.tensorboard import SummaryWriter
  8. try:
  9. from apex import amp
  10. has_apex=True
  11. except:
  12. print('apex not available')
  13. has_apex=False
  14. from torch.utils.data import Dataset, DataLoader
  15. from torchvision import transforms as T
  16. ## THIS IS THE IMPORTANT BIT
  17. from torch.profiler import profile, record_function, ProfilerActivity, schedule
  class CIFAR10_Manager(object):
      """Build CIFAR-10 training and validation ``DataLoader``s.

      Both splits are created with ``datasets.CIFAR10`` rooted at ``indir``
      with ``download=True``. The training split uses the random augmentation
      pipeline from ``get_input_transforms`` and is shuffled. The validation
      split uses the deterministic ``get_eval_transforms`` pipeline and is not
      shuffled, so evaluating on it is repeatable.

      Args:
          indir: Root directory for the CIFAR-10 files.
          bsize: Batch size of both loaders.
          num_workers: Worker processes per ``DataLoader`` (defaults to 8).
      """

      def __init__(self, indir, bsize=128, num_workers=8):
          self.indir = indir
          self.inputsize = (32, 32)
          self.batchsize = bsize
          self.num_workers = num_workers
          # Augmenting pipeline, used for the training split only.
          self.input_transforms = self.get_input_transforms()
          # Deterministic pipeline, used for the validation split.
          self.eval_transforms = self.get_eval_transforms()
          self.train_loader = self.get_train_loader()
          self.valid_loader = self.get_valid_loader()

      def get_train_loader(self):
          """Return a shuffled ``DataLoader`` over the augmented training split."""
          tdata = datasets.CIFAR10(
              root=self.indir,
              train=True,
              transform=self.input_transforms,
              download=True)
          return DataLoader(tdata, self.batchsize, shuffle=True,
                            num_workers=self.num_workers)

      def get_valid_loader(self):
          """Return an unshuffled ``DataLoader`` over the CIFAR-10 test split."""
          vdata = datasets.CIFAR10(
              root=self.indir,
              train=False,
              # No random augmentation: validation metrics should not depend
              # on the RNG state.
              transform=self.eval_transforms,
              download=True)
          # Order does not matter for evaluation, so skip shuffling.
          return DataLoader(vdata, self.batchsize, shuffle=False,
                            num_workers=self.num_workers)

      def _normalize_transform(self):
          """Return the per-channel normalisation shared by both pipelines."""
          # mean = std = 0.5 maps ToTensor's [0, 1] values to [-1, 1].
          return T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

      def get_input_transforms(self):
          """Return the training pipeline: resize, random augmentation, normalise."""
          affine = T.RandomAffine(degrees=5, translate=(0.05, 0.05))
          hflip = T.RandomHorizontalFlip(p=0.7)
          vflip = T.RandomVerticalFlip(p=0.7)
          blur = T.GaussianBlur(5)  # kernel size 5x5
          return T.Compose([T.Resize(self.inputsize), affine, hflip, vflip, blur,
                            T.ToTensor(), self._normalize_transform()])

      def get_eval_transforms(self):
          """Return the deterministic pipeline: resize, to-tensor, normalise."""
          return T.Compose([T.Resize(self.inputsize), T.ToTensor(),
                            self._normalize_transform()])
  class VisionClassifier(nn.Module):
      """Pretrained torchvision backbone topped with a fresh linear head.

      The torchvision model selected by ``mname`` is built with
      ``pretrained=True``. Its last child module (the original classification
      head) is dropped, and the remaining children become ``self.backbone``.
      A new ``nn.Linear(hidden_dim, nclasses)`` head is attached as
      ``self.linear``.

      Args:
          nclasses: Number of output classes of the new head.
          mname: A key of ``SUPPORTED_MODELS``.

      Raises:
          ValueError: If ``mname`` is not a supported model name.
      """

      # User-facing model name -> torchvision.models constructor name. The
      # constructor is looked up lazily with getattr, so an unsupported
      # ``mname`` is rejected before any torchvision attribute is touched.
      SUPPORTED_MODELS = {
          'resnet18': 'resnet18',
          'resnet50': 'resnet50',
          'mobilenetv3': 'mobilenet_v3_large',
          'densenet': 'densenet121',
          'squeezenet': 'squeezenet1_0',
          'inception': 'inception_v3',
      }

      def __init__(self, nclasses, mname='resnet18'):
          super().__init__()
          self.nclasses = nclasses
          ctor_name = self.SUPPORTED_MODELS.get(mname)
          if ctor_name is None:
              # Raise rather than quit(): quit() is meant for the interactive
              # shell and exited with status 0, hiding the failure.
              raise ValueError(
                  f'Model {mname} not supported. '
                  f'Supported models are: {sorted(self.SUPPORTED_MODELS)}'
              )
          print(f'Initializing {mname}')
          # NOTE(review): newer torchvision releases deprecate ``pretrained=``
          # in favour of ``weights=``; kept for compatibility with older ones.
          # inception_v3 also expects much larger inputs than 32x32 images.
          fullmodel = getattr(models, ctor_name)(pretrained=True)

          named = list(fullmodel.named_children())
          head = named[-1][1]
          # Drop the head. Also skip inception_v3's auxiliary classifier: it
          # sits mid-network and must not be chained into the Sequential.
          body = [module for name, module in named[:-1] if name != 'AuxLogits']
          # The kept children do not always end in a global pool (e.g. the
          # DenseNet/SqueezeNet feature stacks). The trailing adaptive pool
          # makes the flattened size equal hidden_dim for any input size. It
          # has no parameters and is appended last, so the Sequential's
          # existing state_dict keys are unchanged.
          self.backbone = nn.Sequential(*body, nn.AdaptiveAvgPool2d(1))
          self.flatten = nn.Flatten()
          # Heads may be a bare Linear or a container such as nn.Sequential
          # (mobilenet/squeezenet), so inspect it rather than reading
          # ``.in_features`` directly.
          hidden_dim = self._head_in_features(head)
          self.linear = nn.Linear(hidden_dim, self.nclasses)

      @staticmethod
      def _head_in_features(head):
          """Return the input width of classification head ``head``.

          ``head`` may be a single layer or a container. The first
          ``nn.Linear`` (its ``in_features``) or ``nn.Conv2d`` (its
          ``in_channels``) found in ``head.modules()`` order decides.

          Raises:
              ValueError: If ``head`` contains neither layer type.
          """
          for layer in head.modules():
              if isinstance(layer, nn.Linear):
                  return layer.in_features
              if isinstance(layer, nn.Conv2d):
                  return layer.in_channels
          raise ValueError(
              f'Cannot infer input width of head {type(head).__name__}')

      def forward(self, x):
          """Return logits of shape ``(batch, nclasses)`` for image batch ``x``."""
          x = self.backbone(x)
          x = self.flatten(x)
          x = self.linear(x)
          return x