main.py

# Authors: Zhaoshuo Li, Xingtong Liu, Francis X. Creighton, Russell H. Taylor, and Mathias Unberath
#
# Copyright (c) 2020. Johns Hopkins University - All rights reserved.

import argparse
import os
import random

import numpy as np
import torch

from dataset import build_data_loader
from module.loss import build_criterion
from module.sttr import STTR
from utilities.checkpoint_saver import Saver
from utilities.eval import evaluate
from utilities.foward_pass import set_downsample
from utilities.inference import inference
from utilities.summary_logger import TensorboardSummary
from utilities.train import train_one_epoch


def get_args_parser():
    """
    Parse arguments
    """
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)

    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-4, type=float)
    parser.add_argument('--lr_regression', default=2e-4, type=float)
    parser.add_argument('--lr_decay_rate', default=0.99, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--ft', action='store_true',
                        help='load model from checkpoint, but discard optimizer state')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--inference', action='store_true')
    parser.add_argument('--num_workers', default=1, type=int)
    parser.add_argument('--checkpoint', type=str, default='dev',
                        help='checkpoint name for current experiment')
    parser.add_argument('--pre_train', action='store_true')
    parser.add_argument('--downsample', default=3, type=int,
                        help='Ratio to downsample width/height')
    parser.add_argument('--apex', action='store_true', help='enable mixed precision training')

    # * STTR
    parser.add_argument('--channel_dim', default=128, type=int,
                        help="Size of the embeddings (dimension of the transformer)")

    # * Positional Encoding
    parser.add_argument('--position_encoding', default='sine1d_rel', type=str,
                        choices=('sine1d_rel', 'none'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--num_attn_layers', default=6, type=int,
                        help="Number of attention layers in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")

    # * Regression Head
    parser.add_argument('--regression_head', default='ot', type=str, choices=('softmax', 'ot'),
                        help='Normalization to be used')
    parser.add_argument('--context_adjustment_layer', default='cal', choices=['cal', 'none'], type=str)
    parser.add_argument('--cal_num_blocks', default=8, type=int)
    parser.add_argument('--cal_feat_dim', default=16, type=int)
    parser.add_argument('--cal_expansion_ratio', default=4, type=int)

    # * Dataset parameters
    parser.add_argument('--dataset', default='sceneflow', type=str, help='dataset to train/eval on')
    parser.add_argument('--dataset_directory', default='', type=str, help='directory to dataset')
    parser.add_argument('--validation', default='validation', type=str,
                        choices={'validation', 'validation_all'},
                        help='whether to validate on all provided training images')

    # * Loss
    parser.add_argument('--px_error_threshold', type=int, default=3,
                        help='Number of pixels for error computation (default 3 px)')
    parser.add_argument('--loss_weight', type=str, default='rr:1.0, l1_raw:1.0, l1:1.0, occ_be:1.0',
                        help='Weight for losses')
    parser.add_argument('--validation_max_disp', type=int, default=-1)

    return parser
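

# Example training invocation (illustrative only; the dataset directory below is a
# placeholder, not a path shipped with the repo):
#   python main.py --dataset sceneflow --dataset_directory /path/to/sceneflow \
#       --batch_size 2 --epochs 300 --checkpoint sttr_dev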


def save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, best, amp=None):
    """
    Save current state of training
    """
    # save model
    checkpoint = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'best_pred': prev_best
    }
    if amp is not None:
        checkpoint['amp'] = amp.state_dict()

    if best:
        checkpoint_saver.save_checkpoint(checkpoint, 'model.pth.tar', write_best=False)
    else:
        checkpoint_saver.save_checkpoint(checkpoint, 'epoch_' + str(epoch) + '_model.pth.tar', write_best=False)
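

# A minimal sketch of reloading a checkpoint written by save_checkpoint above (the
# filename is a placeholder; main() below does the same thing via --resume):
#   checkpoint = torch.load('epoch_0_model.pth.tar')
#   model.load_state_dict(checkpoint['state_dict'], strict=False)
#   optimizer.load_state_dict(checkpoint['optimizer'])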


def print_param(model):
    """
    Print the number of parameters in each part of the model
    """
    n_parameters = sum(p.numel() for n, p in model.named_parameters() if 'backbone' in n and p.requires_grad)
    print('number of params in backbone:', f'{n_parameters:,}')

    n_parameters = sum(p.numel() for n, p in model.named_parameters() if
                       'transformer' in n and 'regression' not in n and p.requires_grad)
    print('number of params in transformer:', f'{n_parameters:,}')

    n_parameters = sum(p.numel() for n, p in model.named_parameters() if 'tokenizer' in n and p.requires_grad)
    print('number of params in tokenizer:', f'{n_parameters:,}')

    n_parameters = sum(p.numel() for n, p in model.named_parameters() if 'regression' in n and p.requires_grad)
    print('number of params in regression:', f'{n_parameters:,}')


def main(args):
    # get device
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # build model
    model = STTR(args).to(device)
    print_param(model)

    # set learning rates per parameter group
    param_dicts = [
        {"params": [p for n, p in model.named_parameters() if
                    "backbone" not in n and "regression" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone,
        },
        {
            "params": [p for n, p in model.named_parameters() if "regression" in n and p.requires_grad],
            "lr": args.lr_regression,
        },
    ]

    # define optimizer and learning rate scheduler
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay_rate)

    # mixed precision training
    if args.apex:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None

    # load checkpoint if provided
    prev_best = np.inf
    if args.resume != '':
        if not os.path.isfile(args.resume):
            raise RuntimeError(f"=> no checkpoint found at '{args.resume}'")
        checkpoint = torch.load(args.resume)

        pretrained_dict = checkpoint['state_dict']
        missing, unexpected = model.load_state_dict(pretrained_dict, strict=False)

        # check missing and unexpected keys
        if len(missing) > 0:
            print("Missing keys: ", ','.join(missing))
            raise Exception("Missing keys.")
        unexpected_filtered = [k for k in unexpected if
                               'running_mean' not in k and 'running_var' not in k]  # skip bn params
        if len(unexpected_filtered) > 0:
            print("Unexpected keys: ", ','.join(unexpected_filtered))
            raise Exception("Unexpected keys.")
        print("Pre-trained model successfully loaded.")

        # if not ft/inference/eval, load states for optimizer, lr_scheduler, amp and prev best
        if not (args.ft or args.inference or args.eval):
            if len(unexpected) > 0:  # loaded checkpoint has bn parameters, i.e. a legacy checkpoint
                raise Exception("Resuming legacy model with BN parameters. Not possible due to BN param change. " +
                                "Do you want to finetune or inference? If so, check your arguments.")
            else:
                args.start_epoch = checkpoint['epoch'] + 1
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                prev_best = checkpoint['best_pred']
                if args.apex:
                    amp.load_state_dict(checkpoint['amp'])
                print("Pre-trained optimizer, lr scheduler and stats successfully loaded.")

    # inference
    if args.inference:
        print("Start inference")
        _, _, data_loader = build_data_loader(args)
        inference(model, data_loader, device, args.downsample)
        return

    # initialize saver and logger
    checkpoint_saver = Saver(args)
    summary_writer = TensorboardSummary(checkpoint_saver.experiment_dir)

    # build dataloader
    data_loader_train, data_loader_val, _ = build_data_loader(args)

    # build loss criterion
    criterion = build_criterion(args)

    # set downsample rate
    set_downsample(args)

    # eval
    if args.eval:
        print("Start evaluation")
        evaluate(model, criterion, data_loader_val, device, 0, summary_writer, True)
        return

    # train
    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        print("Epoch: %d" % epoch)
        train_one_epoch(model, data_loader_train, optimizer, criterion, device, epoch, summary_writer,
                        args.clip_max_norm, amp)

        # step lr if not pre-training
        if not args.pre_train:
            lr_scheduler.step()
            print("current learning rate", lr_scheduler.get_lr())

        # empty cache
        torch.cuda.empty_cache()

        # save every epoch during pre-training, otherwise every 50 epochs
        if args.pre_train or epoch % 50 == 0:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

        # validate
        eval_stats = evaluate(model, criterion, data_loader_val, device, epoch, summary_writer, False)

        # save if best
        if prev_best > eval_stats['epe'] and 0.5 > eval_stats['px_error_rate']:
            save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, True, amp)

    # save final model
    save_checkpoint(epoch, model, optimizer, lr_scheduler, prev_best, checkpoint_saver, False, amp)

    return


if __name__ == '__main__':
    ap = argparse.ArgumentParser('STTR training and evaluation script', parents=[get_args_parser()])
    args_ = ap.parse_args()
    main(args_)
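
# Example evaluation and inference invocations (illustrative only; checkpoint and
# dataset paths are placeholders):
#   python main.py --eval --resume /path/to/model.pth.tar \
#       --dataset sceneflow --dataset_directory /path/to/sceneflow
#   python main.py --inference --resume /path/to/model.pth.tar \
#       --dataset sceneflow --dataset_directory /path/to/sceneflow --downsample 3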