# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm

# Unused imports removed
from utils import fsdp_auto_wrap_policy
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    BitsAndBytesConfig,
)
import torch.distributed as dist

# Unused imports removed
from utils.train_utils import (
    set_tokenizer_params,
    train,
    evaluation,
    freeze_transformer_layers,
    check_frozen_layers_peft_model,
    setup,
    setup_environ_flags,
    cleanup,
    clear_gpu_cache,
    get_parameter_dtypes,
    print_model_size,
    get_policies,
)
from utils.dataset_utils import get_preprocessed_dataset
from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training
import configs
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch.cuda.nccl as nccl
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def main(**kwargs):
    # Update the configuration for the training and sharding process
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(rank)
        setup_environ_flags(rank)

    # Calculate gradient accumulation steps
    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

    # Load the pre-trained model and setup its configuration
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True if train_config.quantization else None,
        device_map="auto" if train_config.quantization else None,
    )

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)
    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # Load the tokenizer and add special tokens
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    tokenizer.add_special_tokens(
        {
            "pad_token": "<PAD>",
        }
    )

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=False,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            policies.apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )

    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    train_sampler = None
    val_sampler = None
    if train_config.enable_fsdp:
        train_sampler = DistributedSampler(
            dataset_train,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
            shuffle=True,
        )
        if train_config.run_validation:
            val_sampler = DistributedSampler(
                dataset_val,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=train_config.batch_size_training,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        sampler=train_sampler if train_sampler else None,
        drop_last=True,
        collate_fn=default_data_collator,
    )

    # Default to None so the call to train() below does not reference an undefined
    # name when run_validation is disabled
    eval_dataloader = None
    if train_config.run_validation:
        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=train_config.val_batch_size,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            sampler=val_sampler if val_sampler else None,
            drop_last=True,
            collate_fn=default_data_collator,
        )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=0.0,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")


if __name__ == "__main__":
    fire.Fire(main)
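
# Example invocations (a sketch, not part of the original script). fire.Fire(main)
# turns command-line flags into kwargs, which update_config copies onto train_config
# and fsdp_config; only flags that appear in this file (model_name, quantization,
# use_peft, enable_fsdp) are shown, and the script name and model path are assumptions.
#
# Single-GPU PEFT fine-tuning with int8 quantization:
#   python llama_finetuning.py --use_peft --quantization --model_name /path/to/llama-2-7b-hf
#
# Multi-GPU FSDP fine-tuning; torchrun provides the LOCAL_RANK/RANK/WORLD_SIZE
# environment variables read in main():
#   torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --model_name /path/to/llama-2-7b-hf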