main.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import torch
  2. from pytorch_lightning import (
  3. Trainer,
  4. seed_everything,
  5. )
  6. from pytorch_lightning.callbacks import ModelCheckpoint
  7. from torchvision.models import (
  8. resnet18,
  9. resnet50,
  10. )
  11. from model.model import (
  12. LitFood101,
  13. LitFood101KD,
  14. )
  15. from utils.args import get_program_level_args
  16. def main():
  17. parser = get_program_level_args()
  18. parser = LitFood101.add_model_specific_args(parser)
  19. parser = Trainer.add_argparse_args(parser)
  20. args = parser.parse_args()
  21. seed_everything(args.seed)
  22. checkpoint_callback = ModelCheckpoint(monitor="avg_val_acc", mode="max")
  23. trainer = Trainer.from_argparse_args(
  24. args,
  25. deterministic=True,
  26. benchmark=False,
  27. checkpoint_callback=checkpoint_callback,
  28. precision=16 if args.amp_level != "O0" else 32,
  29. )
  30. # create model
  31. model = resnet18(pretrained=True)
  32. if args.use_knowledge_distillation:
  33. teacher_model = resnet50(pretrained=False)
  34. model = LitFood101KD(model, teacher_model, args)
  35. else:
  36. model = LitFood101(model, args)
  37. if args.evaluate:
  38. checkpoint = torch.load(args.checkpoint)
  39. model.load_state_dict(checkpoint["state_dict"])
  40. trainer.test(model, test_dataloaders=model.test_dataloader())
  41. return 0
  42. trainer.fit(model)
  43. trainer.test()
  44. if __name__ == "__main__":
  45. main()