# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

import fire
import torch
import torch.distributed as dist
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_int8_training
from pkg_resources import packaging
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DistributedSampler
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
    default_data_collator,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

import policies
from configs import fsdp_config, train_config
from policies import AnyPrecisionAdamW
from utils import fsdp_auto_wrap_policy
from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from utils.dataset_utils import get_preprocessed_dataset
from utils.train_utils import (
    train,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)


def main(**kwargs):
    # Update the configuration for the training and sharding process
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    # Calculate gradient accumulation steps
    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

    # Load the pre-trained model and set up its configuration
    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        """
        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
        This avoids CPU OOM when loading large models like Llama 70B, where the model
        alone would consume 2+ TB of CPU memory (70 * 4 * 8). It adds some communication
        overhead and currently requires the latest PyTorch nightly.
        """
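        # NOTE: with low_cpu_fsdp only rank 0 materializes real weights below; the
        # other ranks build the model on the meta device and receive the weights
        # later through the FSDP wrapper (sync_module_states broadcasts rank 0's
        # parameters, and param_init_fn moves the meta tensors onto the GPU).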
"""        v = packaging.version.parse(torch.__version__)        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701        if not verify_latest_nightly:            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "                            "please install latest nightly.")        if rank == 0:            model = LlamaForCausalLM.from_pretrained(                train_config.model_name,                load_in_8bit=True if train_config.quantization else None,                device_map="auto" if train_config.quantization else None,            )        else:            llama_config = LlamaConfig.from_pretrained(train_config.model_name)            with torch.device("meta"):                model = LlamaForCausalLM(llama_config)    else:        model = LlamaForCausalLM.from_pretrained(            train_config.model_name,            load_in_8bit=True if train_config.quantization else None,            device_map="auto" if train_config.quantization else None,        )    if train_config.enable_fsdp and train_config.use_fast_kernels:        """        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable        using of Flash Attention or Xformer memory-efficient kernels         based on the hardware being used. This would speed up fine-tuning.        """        try:            from optimum.bettertransformer import BetterTransformer            model = BetterTransformer.transform(model)         except ImportError:            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)    # Prepare the model for int8 training if quantization is enabled    if train_config.quantization:        model = prepare_model_for_int8_training(model)    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled    if train_config.enable_fsdp and fsdp_config.pure_bf16:        model.to(torch.bfloat16)    # Load the tokenizer and add special tokens    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)    tokenizer.add_special_tokens(            {                "pad_token": "<PAD>",            }        )    if train_config.use_peft:        peft_config = generate_peft_config(train_config, kwargs)        model = get_peft_model(model, peft_config)        model.print_trainable_parameters()    #setting up FSDP if enable_fsdp is enabled    if train_config.enable_fsdp:        if not train_config.use_peft and train_config.freeze_layers:            freeze_transformer_layers(train_config.num_freeze_layers)        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)        model = FSDP(            model,            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,            sharding_strategy=fsdp_config.sharding_strategy,            device_id=torch.cuda.current_device(),            limit_all_gathers=True,            sync_module_states=train_config.low_cpu_fsdp,            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)            if train_config.low_cpu_fsdp and rank != 0 else None,        )        if fsdp_config.fsdp_activation_checkpointing:            policies.apply_fsdp_checkpointing(model)    elif not train_config.quantization and not train_config.enable_fsdp:        
model.to("cuda")    dataset_config = generate_dataset_config(train_config, kwargs)     # Load and preprocess the dataset for training and validation    dataset_train = get_preprocessed_dataset(        tokenizer,        dataset_config,        split="train",    )    if not train_config.enable_fsdp or rank == 0:        print(f"--> Training Set Length = {len(dataset_train)}")    dataset_val = get_preprocessed_dataset(        tokenizer,        dataset_config,        split="test",    )    if not train_config.enable_fsdp or rank == 0:            print(f"--> Validation Set Length = {len(dataset_val)}")    train_sampler = None    val_sampler = None    if train_config.enable_fsdp:        train_sampler = DistributedSampler(            dataset_train,            rank=dist.get_rank(),            num_replicas=dist.get_world_size(),            shuffle=True,        )        if train_config.run_validation:            val_sampler = DistributedSampler(                dataset_val,                rank=dist.get_rank(),                num_replicas=dist.get_world_size(),            )    # Create DataLoaders for the training and validation dataset    train_dataloader = torch.utils.data.DataLoader(        dataset_train,        batch_size=train_config.batch_size_training,        num_workers=train_config.num_workers_dataloader,        pin_memory=True,        sampler=train_sampler if train_sampler else None,        drop_last=True,        collate_fn=default_data_collator,    )    if train_config.run_validation:        eval_dataloader = torch.utils.data.DataLoader(            dataset_val,            batch_size=train_config.val_batch_size,            num_workers=train_config.num_workers_dataloader,            pin_memory=True,            sampler=val_sampler if val_sampler else None,            drop_last=True,            collate_fn=default_data_collator,        )    # Initialize the optimizer and learning rate scheduler    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":        optimizer = AnyPrecisionAdamW(            model.parameters(),            lr=train_config.lr,            momentum_dtype=torch.bfloat16,            variance_dtype=torch.bfloat16,            use_kahan_summation=False,        )    else:        optimizer = optim.AdamW(            model.parameters(),            lr=train_config.lr,            weight_decay=0.0,        )    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)    # Start the training process    results = train(        model,        train_dataloader,        eval_dataloader,        tokenizer,        optimizer,        scheduler,        gradient_accumulation_steps,        train_config,        fsdp_config if train_config.enable_fsdp else None,        local_rank if train_config.enable_fsdp else None,        rank if train_config.enable_fsdp else None,    )    if not train_config.enable_fsdp or rank==0:        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]if __name__ == "__main__":    fire.Fire(main)