| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438 | # Copyright (c) Meta Platforms, Inc. and affiliates.# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.import dataclassesimport osimport randomfrom collections import Counterfrom warnings import warnimport fireimport numpy as npimport torchimport torch.optim as optimfrom accelerate.utils import is_xpu_availablefrom llama_cookbook.configs import (    fsdp_config as FSDP_CONFIG,    quantization_config as QUANTIZATION_CONFIG,    train_config as TRAIN_CONFIG,)from llama_cookbook.data.concatenator import ConcatDatasetfrom llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointingfrom llama_cookbook.utils import fsdp_auto_wrap_policyfrom llama_cookbook.utils.config_utils import (    check_fsdp_config,    generate_dataset_config,    generate_peft_config,    get_dataloader_kwargs,    update_config,)from llama_cookbook.utils.dataset_utils import (    get_custom_data_collator,    get_preprocessed_dataset,)from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh, get_policiesfrom llama_cookbook.utils.train_utils import (    clear_gpu_cache,    freeze_transformer_layers,    freeze_LLM_only,    print_model_size,    print_frozen_model_status,    setup,    setup_environ_flags,    train,)from peft import get_peft_model, PeftModelfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategyfrom torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffloadfrom torch.optim.lr_scheduler import StepLRfrom transformers import (    AutoConfig,    AutoProcessor,    AutoTokenizer,    BitsAndBytesConfig,    LlamaForCausalLM,    MllamaForConditionalGeneration,)from transformers.models.llama.modeling_llama import LlamaDecoderLayerfrom transformers.models.mllama.modeling_mllama import (    MllamaCrossAttentionDecoderLayer,    MllamaSelfAttentionDecoderLayer,    MllamaVisionEncoderLayer,)def setup_wandb(train_config, fsdp_config, **kwargs):    try:        import wandb    except ImportError:        raise ImportError(            "You are trying to use wandb which is not currently installed. "            "Please install it using pip install wandb"        )    from llama_cookbook.configs import wandb_config as WANDB_CONFIG    wandb_config = WANDB_CONFIG()    update_config(wandb_config, **kwargs)    init_dict = dataclasses.asdict(wandb_config)    run = wandb.init(**init_dict)    run.config.update(train_config)    run.config.update(fsdp_config, allow_val_change=True)    return rundef main(**kwargs):    # Update the configuration for the training and sharding process    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()    update_config((train_config, fsdp_config), **kwargs)    # Set the seeds for reproducibility    if is_xpu_available():        torch.xpu.manual_seed(train_config.seed)    torch.manual_seed(train_config.seed)    random.seed(train_config.seed)    np.random.seed(train_config.seed)    if train_config.enable_fsdp:        setup()        # torchrun specific        local_rank = int(os.environ["LOCAL_RANK"])        rank = int(os.environ["RANK"])        world_size = int(os.environ["WORLD_SIZE"])    if torch.distributed.is_initialized():        if is_xpu_available():            torch.xpu.set_device(local_rank)        elif torch.cuda.is_available():            torch.cuda.set_device(local_rank)        clear_gpu_cache(local_rank)        setup_environ_flags(rank)    wandb_run = None    if train_config.use_wandb:        if not train_config.enable_fsdp or rank == 0:            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)     # setting quantization configs    bnb_config = None    if train_config.quantization:        if type(train_config.quantization) == type(True):            warn(                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",                FutureWarning,            )            train_config.quantization = "8bit"        if train_config.quantization == "8bit" and train_config.enable_fsdp:            raise ValueError(                "8bit quantization is not supported with FSDP, please use 4bit quantization"            )        quant_config = QUANTIZATION_CONFIG()        update_config(quant_config, **kwargs)        bnb_config = quant_config.create_bnb_config(train_config.quantization)               if train_config.enable_fsdp:            if train_config.quantization == "4bit":                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype                from logging import getLogger                logger = getLogger()                logger.warning(                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."                )    # Load the pre-trained model and setup its configuration    use_cache = False if train_config.enable_fsdp else None    config = AutoConfig.from_pretrained(train_config.model_name)    if config.model_type == "mllama":        is_vision = True        model = MllamaForConditionalGeneration.from_pretrained(            train_config.model_name,            quantization_config=bnb_config,            attn_implementation="sdpa" if train_config.use_fast_kernels else None,            device_map=(                "auto"                if train_config.quantization and not train_config.enable_fsdp                else None            ),            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",        )        processor = AutoProcessor.from_pretrained(            train_config.model_name            if train_config.tokenizer_name is None            else train_config.tokenizer_name        )        processor.tokenizer.padding_side = "right"        model.supports_gradient_checkpointing = True        model.language_model.supports_gradient_checkpointing = True    elif config.model_type == "llama":        is_vision = False        model = LlamaForCausalLM.from_pretrained(            train_config.model_name,            quantization_config=bnb_config,            use_cache=use_cache,            attn_implementation="sdpa" if train_config.use_fast_kernels else None,            device_map=(                "auto"                if train_config.quantization and not train_config.enable_fsdp                else None            ),            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",        )    else:        raise ValueError(            f"Model type {config.model_type} is not supported. Please use llama or mllama model."        )    # Load the tokenizer and add special tokens    tokenizer = AutoTokenizer.from_pretrained(        train_config.model_name        if train_config.tokenizer_name is None        else train_config.tokenizer_name    )    if not tokenizer.pad_token_id:        tokenizer.pad_token_id = tokenizer.eos_token_id    # If there is a mismatch between tokenizer vocab size and embedding matrix,    # throw a warning and then expand the embedding matrix    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:        print(            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."        )        model.resize_token_embeddings(len(tokenizer))    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled    if (        train_config.enable_fsdp        and fsdp_config.pure_bf16        and not train_config.quantization    ):        model.to(torch.bfloat16)    if train_config.use_peft:        # Load the pre-trained peft model checkpoint and setup its configuration        if train_config.from_peft_checkpoint:            model = PeftModel.from_pretrained(                model, train_config.from_peft_checkpoint, is_trainable=True            )            peft_config = model.peft_config        # Generate the peft config and start fine-tuning from original model        else:            peft_config = generate_peft_config(train_config, kwargs)            model = get_peft_model(model, peft_config)        if wandb_run:            wandb_run.config.update(peft_config)        model.print_trainable_parameters()    hsdp_device_mesh_plan = None    if (        fsdp_config.hsdp        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD    ):        hsdp_device_mesh_plan = hsdp_device_mesh(            replica_group_size=fsdp_config.replica_group_size,            sharding_group_size=fsdp_config.sharding_group_size,        )        print("HSDP device mesh is ready")    # setting up FSDP if enable_fsdp is enabled    if train_config.enable_fsdp:        check_fsdp_config(fsdp_config)        if not train_config.use_peft and train_config.freeze_layers:            freeze_transformer_layers(model, train_config.num_freeze_layers)            # print model size and frozen layers after freezing layers            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":            freeze_LLM_only(model)            # print model size and frozen layers after freezing layers            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models        if is_vision:            my_auto_wrapping_policy = fsdp_auto_wrap_policy(                model,                [                    MllamaSelfAttentionDecoderLayer,                    MllamaCrossAttentionDecoderLayer,                    MllamaVisionEncoderLayer,                ],            )        else:            # Create the FSDP wrapper for LlamaDecoderLayer in text models            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])        device_id = 0        if is_xpu_available():            device_id = torch.xpu.current_device()        elif torch.cuda.is_available():            device_id = torch.cuda.current_device()               use_orig_params = train_config.freeze_LLM_only or train_config.use_peft                model = FSDP(            model,            auto_wrap_policy=(                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy            ),            cpu_offload=(                CPUOffload(offload_params=True)                if fsdp_config.fsdp_cpu_offload                else None            ),            mixed_precision=(                mixed_precision_policy if not fsdp_config.pure_bf16 else None            ),            sharding_strategy=fsdp_config.sharding_strategy,            device_mesh=hsdp_device_mesh_plan,            device_id=device_id,            limit_all_gathers=True,            sync_module_states=train_config.low_cpu_fsdp,            param_init_fn=(                (                    lambda module: module.to_empty(                        device=torch.device("cuda"), recurse=False                    )                )                if train_config.low_cpu_fsdp and rank != 0                else None            ),            use_orig_params=use_orig_params,        )        if fsdp_config.fsdp_activation_checkpointing:            model.enable_input_require_grads()            model.gradient_checkpointing_enable()            apply_fsdp_checkpointing(model)    elif not train_config.quantization and not train_config.enable_fsdp:        if is_xpu_available():            model.to("xpu:0")        elif torch.cuda.is_available():            model.to("cuda")    dataset_config = generate_dataset_config(train_config, kwargs)    if is_vision:        dataset_processer = processor    else:        dataset_processer = tokenizer    # Load and preprocess the dataset for training and validation    dataset_train = get_preprocessed_dataset(        dataset_processer,        dataset_config,        split="train",    )    if not train_config.enable_fsdp or rank == 0:        print(f"--> Training Set Length = {len(dataset_train)}")    dataset_val = get_preprocessed_dataset(        dataset_processer,        dataset_config,        split="test",    )    if not train_config.enable_fsdp or rank == 0:        print(f"--> Validation Set Length = {len(dataset_val)}")    if train_config.batching_strategy == "packing":        if is_vision:            raise ValueError("Packing is not supported for vision datasets")        else:            dataset_train = ConcatDataset(                dataset_train, chunk_size=train_config.context_length            )    train_dl_kwargs = get_dataloader_kwargs(        train_config, dataset_train, dataset_processer, "train"    )    print("length of dataset_train", len(dataset_train))    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)    if custom_data_collator:        print("custom_data_collator is used")        train_dl_kwargs["collate_fn"] = custom_data_collator    # Create DataLoaders for the training and validation dataset    train_dataloader = torch.utils.data.DataLoader(        dataset_train,        num_workers=train_config.num_workers_dataloader,        pin_memory=True,        **train_dl_kwargs,    )    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")    eval_dataloader = None    if train_config.run_validation:        if train_config.batching_strategy == "packing":            if is_vision:                raise ValueError("Packing is not supported for vision datasets")            else:                dataset_val = ConcatDataset(                    dataset_val, chunk_size=train_config.context_length                )        val_dl_kwargs = get_dataloader_kwargs(            train_config, dataset_val, dataset_processer, "val"        )        if custom_data_collator:            val_dl_kwargs["collate_fn"] = custom_data_collator        eval_dataloader = torch.utils.data.DataLoader(            dataset_val,            num_workers=train_config.num_workers_dataloader,            pin_memory=True,            **val_dl_kwargs,        )        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")        if len(eval_dataloader) == 0:            raise ValueError(                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"            )        else:            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")    # Initialize the optimizer and learning rate scheduler    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":        optimizer = AnyPrecisionAdamW(            model.parameters(),            lr=train_config.lr,            momentum_dtype=torch.bfloat16,            variance_dtype=torch.bfloat16,            use_kahan_summation=False,            weight_decay=train_config.weight_decay,        )    else:        optimizer = optim.AdamW(            model.parameters(),            lr=train_config.lr,            weight_decay=train_config.weight_decay,        )    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)    results = train(        model,        train_dataloader,        eval_dataloader,        tokenizer,        optimizer,        scheduler,        train_config.gradient_accumulation_steps,        train_config,        fsdp_config if train_config.enable_fsdp else None,        local_rank if train_config.enable_fsdp else None,        rank if train_config.enable_fsdp else None,        wandb_run,    )    if not train_config.enable_fsdp or rank == 0:        [print(f"Key: {k}, Value: {v}") for k, v in results.items()]        if train_config.use_wandb:            for k, v in results.items():                wandb_run.summary[k] = vif __name__ == "__main__":    fire.Fire(main)
 |