# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import functools
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import (
    enable_wrap,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
    wrap,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def get_size_policy(min_params=1e8):
    """Return a size-based FSDP auto-wrap policy.

    Args:
        min_params: minimum number of parameters a module must contain to be
            wrapped in its own FSDP unit. Forwarded to
            ``size_based_auto_wrap_policy`` as ``min_num_params``
            (default: 1e8).

    Returns:
        A ``functools.partial`` of ``size_based_auto_wrap_policy`` suitable
        for FSDP's ``auto_wrap_policy`` argument.
    """
    num_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=min_params
    )
    return num_wrap_policy


def get_llama_wrapper():
    """Return a transformer auto-wrap policy keyed on ``LlamaDecoderLayer``.

    We register the model's main layer class and use the FSDP transformer
    wrapping policy; this ensures embedding layers stay in the root FSDP
    unit (for shared access) and that FSDP units map to transformer layers.

    Returns:
        A ``functools.partial`` of ``transformer_auto_wrap_policy`` suitable
        for FSDP's ``auto_wrap_policy`` argument.
    """
    # ====   use new transformer wrapper
    llama_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    return llama_auto_wrap_policy
 |