# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Return an FSDP auto-wrap policy that wraps the model's transformer layers
    as well as standalone trainable modules such as PEFT prompt/prefix encoders."""
    import functools
    import os

    from accelerate import FullyShardedDataParallelPlugin
    from transformers.models.t5.modeling_t5 import T5Block
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap leaf modules whose own weight is trainable (e.g. PEFT adapter layers).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_name,
            # FullyShardedDataParallelPlugin.get_module_class_from_name(
            #     model, transformer_layer_name
            # ),
        ),
    )

    # A module is wrapped if either policy matches.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
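# Illustrative usage sketch (not part of the original file): one way the returned
# policy is typically passed to torch's FullyShardedDataParallel. `model` and
# `LlamaDecoderLayer` below are assumptions for the example; actual training
# scripts may configure sharding differently.
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     from transformers.models.llama.modeling_llama import LlamaDecoderLayer
#
#     wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
#     model = FSDP(model, auto_wrap_policy=wrapping_policy)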
 