template.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. import argparse
  3. from string import Template
  4. import sys
  5. torch = {
  6. '1.9': {
  7. '10.2': {
  8. 'torch_version': '1.9.0+cu102',
  9. 'torchvision_version': '0.10.0+cu102',
  10. 'torchaudio_version': '0.9.0',
  11. 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
  12. },
  13. '11.1': {
  14. 'torch_version': '1.9.0+cu111',
  15. 'torchvision_version': '0.10.0+cu111',
  16. 'torchaudio_version': '0.9.0',
  17. 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
  18. }
  19. },
  20. '1.10': {
  21. '11.3': {
  22. 'torch_version': '1.10.2+cu113',
  23. 'torchvision_version': '0.11.3+cu113',
  24. 'torchaudio_version': '0.10.2+cu113',
  25. 'find_links': (
  26. 'https://download.pytorch.org/whl/cu113/torch_stable.html'
  27. )
  28. }
  29. }
  30. }
  31. def render(torch_version: str, cuda_version: str) -> str:
  32. with open('pytorch_cu.def.template', 'r', encoding='utf8') as f:
  33. template = Template(f.read())
  34. return template.substitute(**torch[torch_version][cuda_version])
  35. def write_def(torch_version: str, cuda_version: str) -> None:
  36. file_name = f'pytorch_{torch_version}_cu_{cuda_version}.def'
  37. text = render(torch_version, cuda_version)
  38. with open(file_name, 'w', encoding='utf8') as f:
  39. f.write(text)
  40. def main() -> None:
  41. parser = argparse.ArgumentParser(
  42. description='Template Pytorch definition files'
  43. )
  44. parser.add_argument(
  45. 'torch',
  46. help='Torch version',
  47. type=str,
  48. nargs='?',
  49. choices=['all'] + list(torch.keys()),
  50. default='all'
  51. )
  52. parser.add_argument(
  53. 'cuda',
  54. help='CUDA version',
  55. type=str,
  56. nargs='?',
  57. default='all',
  58. )
  59. args = parser.parse_args()
  60. match (args.torch, args.cuda):
  61. # Build all def files
  62. case('all', 'all'):
  63. for torch_version, cuda in torch.items():
  64. for cuda_version in cuda.keys():
  65. write_def(torch_version, cuda_version)
  66. # Building all def files for a particular CUDA version is unsupported
  67. case('all', _):
  68. print('Building all def files for a particular CUDA version is'
  69. ' not supported')
  70. sys.exit(1)
  71. # Build all def files for a particular Torch version
  72. case(_, 'all'):
  73. cuda = torch[args.torch]
  74. for cuda_version in cuda.keys():
  75. write_def(args.torch, cuda_version)
  76. # Build a single def file for a single Torch & CUDA combination
  77. case(_, _):
  78. if args.cuda not in (supported_versions :=
  79. torch[args.torch].keys()):
  80. print(f'Torch {args.torch} does not support CUDA {args.cuda}')
  81. print('Supported CUDA versions:'
  82. f' {" ".join(supported_versions)}')
  83. sys.exit(1)
  84. write_def(args.torch, args.cuda)
  85. if __name__ == '__main__':
  86. main()