1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- #!/usr/bin/env python3
- import argparse
- from string import Template
# Template substitution values, keyed by CUDA version string.
# Each entry supplies the torch / torchvision / torchaudio version pins and
# the wheel index URL (presumably fed to pip's --find-links; confirm against
# the template file).
CUDA = {
    '10.2': dict(
        torch_version='1.9.0+cu102',
        torchvision_version='0.10.0+cu102',
        torchaudio_version='0.9.0',
        find_links='https://download.pytorch.org/whl/torch_stable.html',
    ),
    '11.1': dict(
        torch_version='1.9.0+cu111',
        torchvision_version='0.10.0+cu111',
        torchaudio_version='0.9.0',
        find_links='https://download.pytorch.org/whl/torch_stable.html',
    ),
    '11.3': dict(
        torch_version='1.10.2+cu113',
        torchvision_version='0.11.3+cu113',
        torchaudio_version='0.10.2+cu113',
        find_links='https://download.pytorch.org/whl/cu113/torch_stable.html',
    ),
}
def render(cuda_version: str) -> str:
    """Return the definition-file text for one CUDA version.

    Reads ``pytorch_cu.def.template`` (a relative path, so it is resolved
    against the current working directory) and fills its ``$placeholder``
    fields from the ``CUDA`` entry for *cuda_version*.

    Raises:
        FileNotFoundError: if the template file cannot be found.
        KeyError: if *cuda_version* is not in ``CUDA``, or the template
            uses a placeholder the entry does not provide.
    """
    with open('pytorch_cu.def.template', 'r', encoding='utf8') as source:
        skeleton = source.read()
    # Look the version up only after the template has been read, so a
    # missing template is reported before an unknown version.
    values = CUDA[cuda_version]
    return Template(skeleton).substitute(values)
def write_def(cuda_version: str, text: str) -> None:
    """Write *text* to ``pytorch_cu_<cuda_version>.def``.

    The file is created in the current working directory as UTF-8 and
    replaces any existing file of the same name.
    """
    destination = 'pytorch_cu_{}.def'.format(cuda_version)
    with open(destination, mode='w', encoding='utf8') as out:
        out.write(text)
def main() -> None:
    """Parse the command line and generate the requested definition files.

    Takes one positional argument, ``cuda``: either a version listed in
    ``CUDA`` or the literal ``all``, which generates a file for every
    listed version. Each file is rendered and then written in turn.
    """
    cli = argparse.ArgumentParser(
        description='Template Pytorch definition files'
    )
    cli.add_argument(
        'cuda',
        type=str,
        choices=['all', *CUDA],
        help='CUDA version',
    )
    requested = cli.parse_args().cuda

    # 'all' expands to every known version; otherwise process just the one.
    targets = list(CUDA) if requested == 'all' else [requested]
    for version in targets:
        write_def(version, render(version))
# Run the command-line interface only when executed as a script, so the
# module can be imported without side effects.
if __name__ == '__main__':
    main()
|