# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
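
"""Fine-tuning entry point for llama-cookbook.

Loads a Llama text model (LlamaForCausalLM) or a Llama vision model
(MllamaForConditionalGeneration), optionally applies quantization, PEFT
adapters, and FSDP/HSDP sharding, builds the train/validation dataloaders,
and hands everything to llama_cookbook's train() loop. train_config and
fsdp_config fields can be overridden from the command line via fire, e.g.
(illustrative launch only; the flags mirror config fields used below):

    torchrun --nnodes 1 --nproc_per_node 4 <this script> \
        --enable_fsdp --use_peft --model_name <model path or HF id>
"""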

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_cookbook.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_cookbook.data.concatenator import ConcatDataset
from llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_cookbook.utils import fsdp_auto_wrap_policy
from llama_cookbook.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_cookbook.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh, get_policies
from llama_cookbook.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)
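

# Build a Weights & Biases run from llama_cookbook's wandb_config dataclass,
# apply any CLI overrides, and attach the resolved train/FSDP configs to it.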
def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_cookbook.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run
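

# main() wires the full fine-tuning pipeline: config resolution, optional
# quantization, model/tokenizer loading, optional PEFT and FSDP wrapping,
# dataset/dataloader construction, optimizer and scheduler setup, and train().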
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None

    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)
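
    # The checkpoint's config.model_type selects the text path ("llama") or the
    # vision path ("mllama"). KV caching (use_cache) is only useful for
    # generation, so it is disabled when training with FSDP.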
    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None

    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # print a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
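
    # With HYBRID_SHARD, parameters are sharded within each sharding group
    # (typically one node) and replicated across replica groups; hsdp_device_mesh
    # builds the corresponding 2D device mesh consumed by FSDP below.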
    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        if (
            not train_config.use_peft
            and train_config.freeze_LLM_only
            and config.model_type == "mllama"
        ):
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,
        # MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        if train_config.freeze_LLM_only:
            use_orig_params = True
        else:
            use_orig_params = False
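
        # Wrap the model with FSDP. With low_cpu_fsdp, only rank 0 holds real
        # weights before wrapping; the other ranks allocate empty parameters via
        # param_init_fn and receive the weights through sync_module_states.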
        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")
    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")
    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processer, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processer, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )
    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
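
    # StepLR with step_size=1 multiplies the learning rate by gamma on every
    # scheduler step; train() runs the training/evaluation loop and returns a
    # dict of metrics that is printed (and logged to wandb) below.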
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)