# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm

# Unused imports removed
from utils import fsdp_auto_wrap_policy
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    BitsAndBytesConfig,
)
import torch.distributed as dist

# Unused imports removed
from utils.train_utils import (
    set_tokenizer_params,
    train,
    evaluation,
    freeze_transformer_layers,
    check_frozen_layers_peft_model,
    setup,
    setup_environ_flags,
    cleanup,
    clear_gpu_cache,
    get_parameter_dtypes,
    print_model_size,
    get_policies,
)

from utils.dataset_utils import get_preprocessed_dataset

from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training

import configs
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch.cuda.nccl as nccl
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def main(**kwargs):
    # Update the configuration for the training and sharding process
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(rank)
        setup_environ_flags(rank)

    # Calculate gradient accumulation steps
    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size

    # Load the pre-trained model and set up its configuration
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True if train_config.quantization else None,
        device_map="auto" if train_config.quantization else None,
    )

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if FSDP and pure_bf16 are enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # Load the tokenizer and add a padding token
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    tokenizer.add_special_tokens(
        {
            "pad_token": "<PAD>",
        }
    )

    # Wrap the model with PEFT adapters (e.g. LoRA) if enabled
    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Set up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=False,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            policies.apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )

    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    # Use distributed samplers so each rank sees a distinct shard of the data
    train_sampler = None
    val_sampler = None
    if train_config.enable_fsdp:
        train_sampler = DistributedSampler(
            dataset_train,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
            shuffle=True,
        )
        if train_config.run_validation:
            val_sampler = DistributedSampler(
                dataset_val,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )

    # Create DataLoaders for the training and validation datasets
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=train_config.batch_size_training,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        sampler=train_sampler if train_sampler else None,
        drop_last=True,
        collate_fn=default_data_collator,
    )

    # No validation dataloader unless validation is enabled
    eval_dataloader = None
    if train_config.run_validation:
        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=train_config.val_batch_size,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            sampler=val_sampler if val_sampler else None,
            drop_last=True,
            collate_fn=default_data_collator,
        )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=0.0,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")


if __name__ == "__main__":
    fire.Fire(main)
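
# Example launch (illustrative only; the file name, model path, and output_dir
# below are placeholders, not values taken from this script). The FSDP path
# reads the LOCAL_RANK/RANK/WORLD_SIZE environment variables that torchrun sets,
# so a single-node, multi-GPU PEFT run might look like:
#
#   torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py \
#       --enable_fsdp --use_peft --peft_method lora \
#       --model_name /path/to/llama-2-7b --pure_bf16 \
#       --output_dir /path/to/save/peft/model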