template.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #!/usr/bin/env python3
  2. import argparse
  3. from string import Template
  4. CUDA = {
  5. '10.2': {
  6. 'torch_version': '1.9.0+cu102',
  7. 'torchvision_version': '0.10.0+cu102',
  8. 'torchaudio_version': '0.9.0',
  9. 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
  10. },
  11. '11.1': {
  12. 'torch_version': '1.9.0+cu111',
  13. 'torchvision_version': '0.10.0+cu111',
  14. 'torchaudio_version': '0.9.0',
  15. 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
  16. },
  17. '11.3': {
  18. 'torch_version': '1.10.2+cu113',
  19. 'torchvision_version': '0.11.3+cu113',
  20. 'torchaudio_version': '0.10.2+cu113',
  21. 'find_links': (
  22. 'https://download.pytorch.org/whl/cu113/torch_stable.html'
  23. )
  24. }
  25. }
  26. def render(cuda_version: str) -> str:
  27. with open('pytorch_cu.def.template', 'r', encoding='utf8') as f:
  28. template = Template(f.read())
  29. return template.substitute(**CUDA[cuda_version])
  30. def write_def(cuda_version: str, text: str) -> None:
  31. file_name = f'pytorch_cu_{cuda_version}.def'
  32. with open(file_name, 'w', encoding='utf8') as f:
  33. f.write(text)
  34. def main() -> None:
  35. parser = argparse.ArgumentParser(
  36. description='Template Pytorch definition files'
  37. )
  38. parser.add_argument(
  39. 'cuda',
  40. help='CUDA version',
  41. type=str,
  42. choices=['all'] + list(CUDA.keys())
  43. )
  44. args = parser.parse_args()
  45. if args.cuda == 'all':
  46. for cuda_version in CUDA.keys():
  47. write_def(cuda_version, render(cuda_version))
  48. else:
  49. write_def(args.cuda, render(args.cuda))
  50. if __name__ == '__main__':
  51. main()