|
@@ -6,7 +6,6 @@ import time
|
|
|
import yaml
|
|
|
from contextlib import nullcontext
|
|
|
from pathlib import Path
|
|
|
-from pkg_resources import packaging
|
|
|
from datetime import datetime
|
|
|
import contextlib
|
|
|
|
|
@@ -474,7 +473,7 @@ def get_policies(cfg, rank):
|
|
|
verify_bfloat_support = ((
|
|
|
torch.version.cuda
|
|
|
and torch.cuda.is_bf16_supported()
|
|
|
- and packaging.version.parse(torch.version.cuda).release >= (11, 0)
|
|
|
+ and torch.version.cuda >= "11.0"
|
|
|
and dist.is_nccl_available()
|
|
|
and nccl.version() >= (2, 10)
|
|
|
) or
|