|
@@ -5,7 +5,7 @@ import torch
|
|
|
import torch.cuda.nccl as nccl
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
-from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
|
|
|
+from llama_cookbook.policies import fpSixteen,bfSixteen, get_llama_wrapper
|
|
|
from torch.distributed._tensor.device_mesh import init_device_mesh
|
|
|
|
|
|
def fsdp_auto_wrap_policy(model, transformer_layer_names):
|