@@ -8,8 +8,6 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
-    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
-
     def lambda_policy_fn(module):
         if (
             len(list(module.named_children())) == 0
@@ -23,13 +21,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls=(
-            PrefixEncoder,
-            PromptEncoder,
-            PromptEmbedding,
             transformer_layer_name,
-            # FullyShardedDataParallelPlugin.get_module_class_from_name(
-            #     model, transformer_layer_name
-            # ),
         ),
     )
 
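For context, here is a minimal sketch of how an auto-wrap policy along these lines can be built and handed to FSDP after the PEFT prompt-tuning classes are dropped from `transformer_layer_cls`. The helper name `build_auto_wrap_policy`, the `LlamaDecoderLayer` example, and the exact body of `lambda_policy_fn` beyond the lines visible in this diff are assumptions for illustration, not part of the patch.

```python
# Sketch only: assumes a PEFT-wrapped model and a decoder-layer class such as
# LlamaDecoderLayer; these names are illustrative, not taken from this patch.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)


def lambda_policy_fn(module):
    # Wrap leaf modules whose weights still require grad (e.g. adapter layers),
    # so trainable PEFT parameters land in their own FSDP units.
    # (Condition assumed from the surrounding function; only the first check
    # appears in the diff above.)
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


def build_auto_wrap_policy(transformer_layer_cls):
    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_cls,),
    )
    # A module is wrapped if either policy matches it.
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])


# Usage inside an initialized distributed process group (illustrative):
# policy = build_auto_wrap_policy(LlamaDecoderLayer)
# model = FSDP(model, auto_wrap_policy=policy)
```

Combining the lambda policy with the transformer policy via `_or_policy` keeps the behavior of the patched function: transformer blocks are sharded by class, while the small trainable PEFT modules are caught by the lambda check instead of being listed explicitly.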