|
@@ -0,0 +1,64 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+
|
|
|
+import argparse
|
|
|
+from string import Template
|
|
|
+
|
|
|
+CUDA = {
|
|
|
+ '10.2': {
|
|
|
+ 'torch_version': '1.9.0+cu102',
|
|
|
+ 'torchvision_version': '0.10.0+cu102',
|
|
|
+ 'torchaudio_version': '0.9.0',
|
|
|
+ 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
|
|
|
+ },
|
|
|
+ '11.1': {
|
|
|
+ 'torch_version': '1.9.0+cu111',
|
|
|
+ 'torchvision_version': '0.10.0+cu111',
|
|
|
+ 'torchaudio_version': '0.9.0',
|
|
|
+ 'find_links': 'https://download.pytorch.org/whl/torch_stable.html'
|
|
|
+ },
|
|
|
+ '11.3': {
|
|
|
+ 'torch_version': '1.10.2+cu113',
|
|
|
+ 'torchvision_version': '0.11.3+cu113',
|
|
|
+ 'torchaudio_version': '0.10.2+cu113',
|
|
|
+ 'find_links': (
|
|
|
+ 'https://download.pytorch.org/whl/cu113/torch_stable.html'
|
|
|
+ )
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+def render(cuda_version: str) -> str:
|
|
|
+ with open('pytorch_cu.def.template', 'r', encoding='utf8') as f:
|
|
|
+ template = Template(f.read())
|
|
|
+
|
|
|
+ return template.substitute(**CUDA[cuda_version])
|
|
|
+
|
|
|
+
|
|
|
+def write_def(cuda_version: str, text: str) -> None:
|
|
|
+ file_name = f'pytorch_cu_{cuda_version}.def'
|
|
|
+ with open(file_name, 'w', encoding='utf8') as f:
|
|
|
+ f.write(text)
|
|
|
+
|
|
|
+
|
|
|
+def main() -> None:
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description='Template Pytorch definition files'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ 'cuda',
|
|
|
+ help='CUDA version',
|
|
|
+ type=str,
|
|
|
+ choices=['all'] + list(CUDA.keys())
|
|
|
+ )
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ if args.cuda == 'all':
|
|
|
+ for cuda_version in CUDA.keys():
|
|
|
+ write_def(cuda_version, render(cuda_version))
|
|
|
+ else:
|
|
|
+ write_def(args.cuda, render(args.cuda))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ main()
|