@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from collections import Counter
 import os
 
 import dataclasses
@@ -8,7 +9,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
+from peft import get_peft_model, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -18,6 +19,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
     AutoTokenizer,
+    BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
 )
@@ -25,6 +27,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
@@ -48,6 +51,7 @@ from llama_recipes.utils.train_utils import (
     get_policies,
 )
 from accelerate.utils import is_xpu_available
+from warnings import warn
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
@@ -66,7 +70,6 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     run.config.update(fsdp_config, allow_val_change=True)
     return run
 
-
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
@@ -97,38 +100,31 @@ def main(**kwargs):
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
+    #setting quantization configs
+    bnb_config = None
+    if train_config.quantization:
+        if type(train_config.quantization) == type(True):
+            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+            train_config.quantization = "8bit"
+
+        if train_config.quantization == "8bit" and train_config.enable_fsdp:
+            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
+
+        quant_config = QUANTIZATION_CONFIG()
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
-        """
-        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
-        this avoids cpu oom when loading large models like llama 70B, in which case
-        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
-        overhead and currently requires latest nightly.
-        """
-        if rank == 0:
-            model = LlamaForCausalLM.from_pretrained(
-                train_config.model_name,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-            )
-        else:
-            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
-
-    else:
-        model = LlamaForCausalLM.from_pretrained(
-            train_config.model_name,
-            load_in_8bit=True if train_config.quantization else None,
-            device_map="auto" if train_config.quantization else None,
-            use_cache=use_cache,
-            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
-        )
+    model = LlamaForCausalLM.from_pretrained(
+        train_config.model_name,
+        quantization_config=bnb_config,
+        use_cache=use_cache,
+        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+        device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+    )
 
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
@@ -142,14 +138,10 @@ def main(**kwargs):
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
-    # Prepare the model for int8 training if quantization is enabled
-    if train_config.quantization:
-        model = prepare_model_for_kbit_training(model)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
-    if train_config.enable_fsdp and fsdp_config.pure_bf16:
+    if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization:
         model.to(torch.bfloat16)
-
+
     if train_config.use_peft:
         # Load the pre-trained peft model checkpoint and setup its configuration
         if train_config.from_peft_checkpoint:
@@ -181,7 +173,6 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -195,8 +186,10 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
-        if fsdp_config.fsdp_activation_checkpointing:
-            apply_fsdp_checkpointing(model)
+        if fsdp_config.fsdp_activation_checkpointing:
+            model.enable_input_require_grads()
+            model.gradient_checkpointing_enable()
+            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
             model.to("xpu:0")
@@ -211,7 +204,6 @@ def main(**kwargs):
         dataset_config,
         split="train",
     )
-
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -271,7 +263,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
     # Start the training process
     results = train(
         model,