# train.py (529 B)
  1. import darklight as dl
  2. import torch
  3. from vit import VisionTransformer
  4. import vitconfigs as vcfg
  5. net=VisionTransformer(vcfg.base)
  6. dm=dl.ImageNetManager('/sfnvme/imagenet/', size=[224,224], bsize=128)
  7. opt_params={
  8. 'optimizer': torch.optim.AdamW,
  9. 'okwargs': {'lr': 1e-4, 'weight_decay':0.05},
  10. 'scheduler':torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
  11. 'skwargs': {'T_0':10,'T_mult':2},
  12. 'amplevel': None
  13. }
  14. trainer=dl.StudentTrainer(net, dm, None, opt_params=opt_params)
  15. trainer.train(epochs=300, save='vitbase_{}.pth')