# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm

# Unused imports:
# import torch.nn as nn
# import bitsandbytes as bnb
from torch.nn import functional as F
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch.distributed.fsdp import StateDictType
import torch.distributed as dist
from pkg_resources import packaging
from .memory_utils import MemoryTrace
import model_checkpointing
import torch.cuda.nccl as nccl
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from pathlib import Path

# Make the sibling `policies` package importable (module-level side effect, kept as-is).
sys.path.append(str(Path(__file__).resolve().parent.parent))
from policies import bfSixteen, fpSixteen, bfSixteen_mixed, get_llama_wrapper


def set_tokenizer_params(tokenizer: LlamaTokenizer):
    """Configure padding on ``tokenizer`` in place: pad token id 0, padding on the left."""
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"


# Converting Bytes to Megabytes
def byte2mb(x):
    """Convert a byte count to whole mebibytes (units of 2**20 bytes), truncating toward zero."""
    return int(x / 2**20)


def _target_device(train_config, local_rank):
    """Return the device batches are moved to: ``local_rank`` under FSDP, otherwise "cuda"."""
    return local_rank if train_config.enable_fsdp else "cuda"


def _global_mean(local_sum, local_count, train_config):
    """Return ``sum / count`` where both are summed over every process when running under FSDP.

    Reducing the count together with the sum keeps the mean correct even if ranks processed
    different numbers of batches (previously only the sum was reduced).
    """
    if (
        train_config.enable_fsdp
        and dist.is_available()
        and dist.is_initialized()
        and dist.get_world_size() > 1
    ):
        totals = torch.stack(
            [local_sum, torch.tensor(float(local_count), device=local_sum.device)]
        )
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        return totals[0] / totals[1]
    return local_sum / local_count


def _save_checkpoint(model, optimizer, train_config, fsdp_config, rank):
    """Save PEFT adapters, or the FSDP model (and optionally optimizer) in the configured format."""
    if train_config.use_peft:
        print(f"we are in the saving the PEFT modules")
        model.save_pretrained(train_config.output_dir)
        print(f"PEFT modules are saved in {train_config.output_dir} directory")
        return

    if fsdp_config is None:
        # Previously this surfaced as an opaque AttributeError on `None.checkpoint_type`.
        raise ValueError("fsdp_config is required to checkpoint a non-PEFT model")

    # NOTE(review): epoch=1 is hard-coded in the original calls below; if the helpers embed the
    # epoch in file names, every save overwrites the same file — confirm this is intended.
    if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
        model_checkpointing.save_model_checkpoint(
            model, optimizer, rank, train_config, epoch=1
        )
    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
        print(" we are about to save the models *******")
        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
        if train_config.save_optimizer:
            model_checkpointing.save_model_and_optimizer_sharded(
                model, rank, train_config, optim=optimizer
            )
    if train_config.save_optimizer:
        model_checkpointing.save_optimizer_checkpoint(
            model, optimizer, rank, train_config, epoch=1
        )


def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
    """
    Train the model for ``train_config.num_epochs`` epochs, optionally evaluating after each one.

    Args:
        model: The model to be trained.
        train_dataloader: The dataloader containing the training data.
        eval_dataloader: The dataloader containing the eval data.
        tokenizer: Tokenizer forwarded to ``evaluation``.
        optimizer: The optimizer used for training.
        lr_scheduler: The learning rate scheduler; stepped exactly once per epoch.
        gradient_accumulation_steps: Number of batches whose gradients are accumulated before
            each optimizer update.
        train_config: The training configuration (num_epochs, use_fp16, enable_fsdp,
            run_validation, save_model, use_peft, save_optimizer, output_dir).
        fsdp_config: FSDP configuration; its ``checkpoint_type`` selects the checkpoint format
            when a non-PEFT model is saved.
        local_rank: Used as the CUDA device index for batches when ``enable_fsdp`` is set.
        rank: Global process rank, forwarded to the checkpoint helpers.

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    # Create a gradient scaler for fp16; FSDP needs the sharded variant.
    scaler = None
    if train_config.use_fp16:
        scaler = ShardedGradScaler() if train_config.enable_fsdp else torch.cuda.amp.GradScaler()

    device = _target_device(train_config, local_rank)
    num_steps = len(train_dataloader)

    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []
    results = {}
    best_val_loss = float("inf")
    for epoch in range(train_config.num_epochs):
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = torch.zeros((), device=device)
            num_batches = 0
            for step, batch in enumerate(tqdm(train_dataloader, colour="blue", desc=f"Training Epoch{epoch}")):
                for key in batch.keys():
                    batch[key] = batch[key].to(device)
                outputs = model(**batch)
                # Report the unscaled batch loss; only backward uses the accumulation scaling.
                total_loss += outputs.loss.detach().float()
                num_batches += 1
                loss = outputs.loss / gradient_accumulation_steps

                # Update on every accumulation boundary and on the final (possibly partial) group.
                is_update_step = (step + 1) % gradient_accumulation_steps == 0 or step == num_steps - 1
                if scaler is not None:
                    # if fp16 is enabled, use gradient scaler to handle gradient update
                    scaler.scale(loss).backward()
                    if is_update_step:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    # regular backpropagation when fp16 is not used
                    loss.backward()
                    if is_update_step:
                        optimizer.step()
                        optimizer.zero_grad()

                print(f"\n step {step} is completed and loss is {loss.detach().float()}")

        # Mean batch loss over every batch processed by every rank.
        train_epoch_loss = _global_mean(total_loss, num_batches, train_config)
        train_perplexity = torch.exp(train_epoch_loss)

        train_prep.append(train_perplexity)
        train_loss.append(train_epoch_loss)

        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
        print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

        # Update the learning rate once per epoch (it was previously stepped twice).
        lr_scheduler.step()

        if train_config.run_validation:
            # evaluation() uses this argument as a device, so it must be the local rank.
            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
            if train_config.save_model and eval_epoch_loss < best_val_loss:
                _save_checkpoint(model, optimizer, train_config, fsdp_config, rank)

            # Every rank must track best_val_loss: the eval loss is identical across ranks after
            # the all-reduce, so all ranks then agree on whether to enter the collective save.
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                if not train_config.enable_fsdp or local_rank == 0:
                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
            val_loss.append(eval_epoch_loss)
            val_prep.append(eval_ppl)

        print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")

    if train_prep:  # num_epochs == 0 would otherwise divide by zero
        results['avg_train_prep'] = sum(train_prep) / len(train_prep)
        results['avg_train_loss'] = sum(train_loss) / len(train_loss)
    if train_config.run_validation and val_prep:
        results['avg_eval_prep'] = sum(val_prep) / len(val_prep)
        results['avg_eval_loss'] = sum(val_loss) / len(val_loss)

    return results


def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluate the model on the given dataloader.

    Args:
        model: The model to evaluate.
        train_config: The training configuration (reads ``enable_fsdp``).
        eval_dataloader: The dataloader containing the evaluation data.
        local_rank: Used as the CUDA device index for batches when ``enable_fsdp`` is set.
        tokenizer: Accepted for interface compatibility; no longer used (predictions were
            previously decoded into a list that was never read).

    Returns: eval_ppl, eval_epoch_loss
    """
    model.eval()
    device = _target_device(train_config, local_rank)
    eval_loss = torch.zeros((), device=device)
    num_batches = 0
    with MemoryTrace():
        for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch"):
            for key in batch.keys():
                batch[key] = batch[key].to(device)
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                outputs = model(**batch)
                eval_loss += outputs.loss.detach().float()
            num_batches += 1

    # Mean batch loss over every batch evaluated by every rank, and the matching perplexity.
    eval_epoch_loss = _global_mean(eval_loss, num_batches, train_config)
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    print(f" {eval_ppl=} {eval_epoch_loss=}")
    return eval_ppl, eval_epoch_loss


def freeze_transformer_layers(model, num_layer):
    """Disable gradients for the first ``num_layer`` entries of ``model.model.layers``."""
    for i, layer in enumerate(model.model.layers):
        if i < num_layer:
            for param in layer.parameters():
                param.requires_grad = False


def check_frozen_layers_peft_model(model):
    """Print ``requires_grad`` for every parameter of each layer of a PEFT-wrapped model."""
    for i, layer in enumerate(model.base_model.model.model.layers):
        for name, param in layer.named_parameters():
            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")


def setup():
    """Initialize the process group for distributed training"""
    dist.init_process_group("nccl")


def setup_environ_flags(rank):
    """Set environment flags for debugging purposes"""
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    if rank == 0:
        print(f"--> Running with torch dist debug set to detail")


def cleanup():
    """Clean up the process group after training"""
    dist.destroy_process_group()


def clear_gpu_cache(rank=None):
    """Release cached CUDA memory in this process; only rank 0 logs the call."""
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    torch.cuda.empty_cache()


def get_parameter_dtypes(model):
    """Return a mapping from each named parameter of ``model`` to its dtype."""
    return {name: parameter.dtype for name, parameter in model.named_parameters()}


def print_model_size(model, config, rank: int = 0) -> None:
    """
    Print the model name and its number of trainable parameters, on rank 0 only.

    Args:
        model: The PyTorch model.
        config: Configuration object; only ``config.model_name`` is read.
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank == 0:
        print(f"--> Model {config.model_name}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")


def _bf16_supported():
    """Return True when CUDA >= 11, bf16-capable hardware, and NCCL >= 2.10 are all present."""
    return bool(
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )


def get_policies(cfg, rank):
    """Get the policies for mixed precision and fsdp wrapping.

    Returns a ``(mixed_precision_policy, wrapping_policy)`` tuple; the mixed-precision policy is
    None when ``cfg.mixed_precision`` is off or neither bf16 nor fp16 applies.
    """
    mixed_precision_policy = None

    # Mixed precision (the bf16 hardware/NCCL probe only runs when it is actually needed).
    if cfg.mixed_precision:
        bf16_ready = _bf16_supported()
        if bf16_ready and not cfg.use_fp16:
            mixed_precision_policy = bfSixteen_mixed
            if rank == 0:
                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                print(f"FP16 enabled")
        elif rank == 0:
            print(f"bFloat16 support not present. Using FP32, and not mixed precision")

    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy