train.py

"""Performs training of a specified model.

Input params:
    config (dict): Configuration to train with.
"""
from dataset import load_data
from models import MRnet
from config import config
import torch
from torch.utils.tensorboard import SummaryWriter
from utils import _train_model, _evaluate_model, _get_lr
import time
import torch.utils.data as data
import os
def train(config: dict):
    """
    Function where actual training takes place

    Args:
        config (dict): Configuration to train with
    """
    print('Starting to Train Model...')
    train_loader, val_loader, train_wts, val_wts = load_data(config['task'])

    print('Initializing Model...')
    model = MRnet()
    if torch.cuda.is_available():
        model = model.cuda()
        train_wts = train_wts.cuda()
        val_wts = val_wts.cuda()
    print('Initializing Loss Method...')
    # Class imbalance is handled via pos_weight on BCEWithLogitsLoss.
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=train_wts)
    val_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=val_wts)
    if torch.cuda.is_available():
        criterion = criterion.cuda()
        val_criterion = val_criterion.cuda()
    print('Setup the Optimizer')
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, factor=.3, threshold=1e-4, verbose=True)
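    # The scheduler multiplies the LR by 0.3 whenever the validation loss has
    # not improved by more than 1e-4 for 3 consecutive epochs (driven by the
    # scheduler.step(val_loss) call inside the training loop below).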
    starting_epoch = config['starting_epoch']
    num_epochs = config['max_epoch']
    patience = config['patience']  # read from config but not applied in the loop below
    log_train = config['log_train']
    log_val = config['log_val']

    best_val_loss = float('inf')
    best_val_auc = float(0)

    print('Starting Training')
    writer = SummaryWriter(comment='lr={} task={}'.format(config['lr'], config['task']))

    t_start_training = time.time()
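    # SummaryWriter logs land under ./runs/ by default; inspect them with
    # `tensorboard --logdir runs` (assumes the tensorboard package is installed).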
    for epoch in range(starting_epoch, num_epochs):
        current_lr = _get_lr(optimizer)
        epoch_start_time = time.time()  # timer for entire epoch

        train_loss, train_auc = _train_model(
            model, train_loader, epoch, num_epochs, optimizer, criterion, writer, current_lr, log_train)
        val_loss, val_auc = _evaluate_model(
            model, val_loader, val_criterion, epoch, num_epochs, writer, current_lr, log_val)

        writer.add_scalar('Train/Avg Loss', train_loss, epoch)
        writer.add_scalar('Val/Avg Loss', val_loss, epoch)

        scheduler.step(val_loss)

        t_end = time.time()
        delta = t_end - epoch_start_time

        print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s".format(
            train_loss, train_auc, val_loss, val_auc, delta))
        print('-' * 30)

        writer.flush()
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            if bool(config['save_model']):
                file_name = 'model_{}_{}_val_auc_{:0.4f}_train_auc_{:0.4f}_epoch_{}.pth'.format(
                    config['exp_name'], config['task'], val_auc, train_auc, epoch + 1)
                save_dir = './weights/{}'.format(config['task'])
                os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
                torch.save({
                    'model_state_dict': model.state_dict()
                }, os.path.join(save_dir, file_name))
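                # To restore this checkpoint later (a sketch; assumes the same
                # MRnet definition is importable):
                #   checkpoint = torch.load(os.path.join(save_dir, file_name))
                #   model.load_state_dict(checkpoint['model_state_dict'])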

    t_end_training = time.time()
    print(f'training took {t_end_training - t_start_training} s')

    writer.flush()
    writer.close()

if __name__ == '__main__':
    print('Training Configuration')
    print(config)

    train(config=config)

    print('Training Ended...')