#!/usr/bin/env python3
 
import argparse
import sys
from string import Template
 
# Supported build matrix: Torch release series -> CUDA version -> the values
# substituted into the definition-file template (pinned package versions and
# the ``find_links`` URL).
torch = {
    '1.9': {
        '10.2': {
            'torch_version': '1.9.0+cu102',
            'torchvision_version': '0.10.0+cu102',
            'torchaudio_version': '0.9.0',
            'find_links': 'https://download.pytorch.org/whl/torch_stable.html',
        },
        '11.1': {
            'torch_version': '1.9.0+cu111',
            'torchvision_version': '0.10.0+cu111',
            'torchaudio_version': '0.9.0',
            'find_links': 'https://download.pytorch.org/whl/torch_stable.html',
        },
    },
    '1.10': {
        '11.3': {
            'torch_version': '1.10.2+cu113',
            'torchvision_version': '0.11.3+cu113',
            'torchaudio_version': '0.10.2+cu113',
            'find_links': (
                'https://download.pytorch.org/whl/cu113/torch_stable.html'
            ),
        },
    },
}
 
def render(torch_version: str, cuda_version: str) -> str:
    """Render the definition-file text for one Torch/CUDA combination.

    Reads ``pytorch_cu.def.template`` (resolved against the current working
    directory) and fills its ``$placeholders`` with the matching entry of
    the module-level ``torch`` table.

    Args:
        torch_version: Key of the ``torch`` table, e.g. ``'1.10'``.
        cuda_version: Key of ``torch[torch_version]``, e.g. ``'11.3'``.

    Returns:
        The template text with every placeholder substituted.

    Raises:
        OSError: If the template file cannot be opened or read.
        KeyError: If the Torch/CUDA combination is not in the table, or the
            template names a placeholder the table entry does not provide.
        ValueError: If the template contains an invalid placeholder.
    """
    with open('pytorch_cu.def.template', 'r', encoding='utf8') as f:
        template = Template(f.read())
    # ``substitute`` (not ``safe_substitute``) so a missing value fails loudly
    # instead of leaving a raw ``$placeholder`` in the generated file.
    return template.substitute(torch[torch_version][cuda_version])
 
def write_def(torch_version: str, cuda_version: str) -> None:
    """Write the rendered definition file for one Torch/CUDA combination.

    The output is written to ``pytorch_<torch_version>_cu_<cuda_version>.def``
    (resolved against the current working directory); an existing file of
    that name is overwritten.

    Args:
        torch_version: Torch version passed through to ``render``.
        cuda_version: CUDA version passed through to ``render``.
    """
    file_name = f'pytorch_{torch_version}_cu_{cuda_version}.def'
    # Render before opening the output so a rendering failure does not leave
    # an empty or truncated file behind.
    text = render(torch_version, cuda_version)
    with open(file_name, 'w', encoding='utf8') as f:
        f.write(text)
 
- def main() -> None:
 
-     parser = argparse.ArgumentParser(
 
-         description='Template Pytorch definition files'
 
-     )
 
-     parser.add_argument(
 
-         'torch',
 
-         help='Torch version',
 
-         type=str,
 
-         nargs='?',
 
-         choices=['all'] + list(torch.keys()),
 
-         default='all'
 
-     )
 
-     parser.add_argument(
 
-         'cuda',
 
-         help='CUDA version',
 
-         type=str,
 
-         nargs='?',
 
-         default='all',
 
-     )
 
-     args = parser.parse_args()
 
-     match (args.torch, args.cuda):
 
-         # Build all def files
 
-         case('all', 'all'):
 
-             for torch_version, cuda in torch.items():
 
-                 for cuda_version in cuda.keys():
 
-                     write_def(torch_version, cuda_version)
 
-         # Building all def files for a particular CUDA version is unsupported
 
-         case('all', _):
 
-             print('Building all def files for a particular CUDA version is'
 
-                   ' not supported')
 
-             sys.exit(1)
 
-         # Build all def files for a particular Torch version
 
-         case(_, 'all'):
 
-             cuda = torch[args.torch]
 
-             for cuda_version in cuda.keys():
 
-                 write_def(args.torch, cuda_version)
 
-         # Build a single def file for a single Torch & CUDA combination
 
-         case(_, _):
 
-             if args.cuda not in (supported_versions :=
 
-                                  torch[args.torch].keys()):
 
-                 print(f'Torch {args.torch} does not support CUDA {args.cuda}')
 
-                 print('Supported CUDA versions:'
 
-                       f' {" ".join(supported_versions)}')
 
-                 sys.exit(1)
 
-             write_def(args.torch, args.cuda)
 
# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()
 
 
  |