train.py

  1. r"""PyTorch Detection Training.
  2. To run in a multi-gpu environment, use the distributed launcher::
  3. python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
  4. train.py ... --world-size $NGPU
  5. The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
  6. --lr 0.02 --batch-size 2 --world-size 8
  7. If you use different number of gpus, the learning rate should be changed to 0.02/8*$NGPU.
  8. On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
  9. --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3
  10. Also, if you train Keypoint R-CNN, the default hyperparameters are
  11. --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
  12. Because the number of images is smaller in the person keypoint subset of COCO,
  13. the number of epochs should be adapted so that we have the same number of iterations.
  14. """
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn

from coco_utils import get_coco, get_coco_kp
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from engine import train_one_epoch, evaluate
import utils
import transforms as T

# ### OUR CODE ###
# import tensorboard and w&b
import tensorboard
import wandb
# ### END OF OUR CODE ###
def get_dataset(name, image_set, transform, data_path):
    paths = {
        "coco": (data_path, get_coco, 91),
        "coco_kp": (data_path, get_coco_kp, 2)
    }
    p, ds_fn, num_classes = paths[name]

    ds = ds_fn(p, image_set=image_set, transforms=transform)
    return ds, num_classes
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)
def main(args):
    utils.init_distributed_mode(args)
    print(args)

    # apply logging only in the main process
    # ### OUR CODE ###
    if utils.is_main_process():
        # pass the argparse config with hyperparameters to the tensorboard helper
        tensorboard.args = vars(args)
        # init wandb once, with the config, the experiment name,
        # and tensorboard sync enabled
        wandb.init(config=vars(args), name=tensorboard.name, sync_tensorboard=True)
    # ### END OF OUR CODE ###
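    # Note (not from the original script): per the W&B docs, wandb.init with
    # sync_tensorboard=True patches tensorboard at init time, so it should run
    # before any tensorboard writer is created (here, presumably inside the
    # local `tensorboard` helper or the engine code).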
    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
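    # Grouping batches by aspect ratio (below) puts similarly-shaped images in
    # the same batch, which reduces the padding needed when images are collated
    # and therefore saves memory; a negative group factor disables the grouping.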
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)
  86. print("Creating model")
  87. model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
  88. pretrained=args.pretrained)
  89. model.to(device)
  90. model_without_ddp = model
  91. if args.distributed:
  92. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
  93. model_without_ddp = model.module
  94. params = [p for p in model.parameters() if p.requires_grad]
  95. optimizer = torch.optim.SGD(
  96. params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
  97. # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
  98. lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
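    # With the defaults (lr 0.02, lr-steps 16 22, lr-gamma 0.1) the schedule is:
    # epochs 0-15 -> lr 0.02, epochs 16-21 -> lr 0.002, epochs 22-25 -> lr 0.0002.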
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return
  108. print("Start training")
  109. start_time = time.time()
  110. for epoch in range(args.start_epoch, args.epochs):
  111. if args.distributed:
  112. train_sampler.set_epoch(epoch)
  113. train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
  114. lr_scheduler.step()
  115. if args.output_dir:
  116. utils.save_on_master({
  117. 'model': model_without_ddp.state_dict(),
  118. 'optimizer': optimizer.state_dict(),
  119. 'lr_scheduler': lr_scheduler.state_dict(),
  120. 'args': args,
  121. 'epoch': epoch},
  122. os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
  123. # evaluate after every epoch
  124. evaluate(model, data_loader_test, device=device)
  125. total_time = time.time() - start_time
  126. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  127. print('Training time {}'.format(total_time_str))
  128. if __name__ == "__main__":
  129. import argparse
  130. parser = argparse.ArgumentParser(
  131. description=__doc__)
  132. parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
  133. parser.add_argument('--dataset', default='coco', help='dataset')
  134. parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
  135. parser.add_argument('--device', default='cuda', help='device')
  136. parser.add_argument('-b', '--batch-size', default=2, type=int,
  137. help='images per gpu, the total batch size is $NGPU x batch_size')
  138. parser.add_argument('--epochs', default=26, type=int, metavar='N',
  139. help='number of total epochs to run')
  140. parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
  141. help='number of data loading workers (default: 4)')
  142. parser.add_argument('--lr', default=0.02, type=float,
  143. help='initial learning rate, 0.02 is the default value for training '
  144. 'on 8 gpus and 2 images_per_gpu')
  145. parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
  146. help='momentum')
  147. parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
  148. metavar='W', help='weight decay (default: 1e-4)',
  149. dest='weight_decay')
  150. parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
  151. parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, help='decrease lr every step-size epochs')
  152. parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
  153. parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
  154. parser.add_argument('--output-dir', default='.', help='path where to save')
  155. parser.add_argument('--resume', default='', help='resume from checkpoint')
  156. parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
  157. parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
  158. parser.add_argument(
  159. "--test-only",
  160. dest="test_only",
  161. help="Only test the model",
  162. action="store_true",
  163. )
  164. parser.add_argument(
  165. "--pretrained",
  166. dest="pretrained",
  167. help="Use pre-trained models from the modelzoo",
  168. action="store_true",
  169. )
  170. # distributed training parameters
  171. parser.add_argument('--world-size', default=1, type=int,
  172. help='number of distributed processes')
  173. parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
  174. args = parser.parse_args()
  175. if args.output_dir:
  176. utils.mkdir(args.output_dir)
  177. main(args)
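# Example single-gpu run (a sketch, not from the original script; the data path
# is hypothetical). The lr is scaled down from the 8-gpu default: 0.02/8 = 0.0025.
#   python train.py --data-path /path/to/coco --model maskrcnn_resnet50_fpn \
#       --epochs 26 --lr-steps 16 22 --lr 0.0025 --batch-size 2 --world-size 1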