@@ -27,7 +27,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
-from llama_recipes.configs.quantization import quantization_config
+from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
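Note: the aliased import above only resolves if llama_recipes/configs/__init__.py re-exports the dataclass. A one-line sketch of the assumed re-export, inferred from the diff rather than shown in it:

# llama_recipes/configs/__init__.py (assumed re-export, not part of this diff)
from llama_recipes.configs.quantization import quantization_config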
@@ -51,6 +51,7 @@ from llama_recipes.utils.train_utils import (
     get_policies,
 )
 from accelerate.utils import is_xpu_available
+from warnings import warn
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
@@ -100,14 +101,19 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
-    #setting quantization configs
-    quant_config = quantization_config(
-        quant_type=train_config.quantization,
-        compute_dtype=torch.bfloat16,
-        use_double_quant=True,
-        quant_storage=torch.bfloat16
-    ) if train_config.quantization else None
-    bnb_config = quant_config.create_bnb_config() if quant_config else None
+    #setting quantization configs
+    bnb_config = None
+    if train_config.quantization:
+        if type(train_config.quantization) == type(True):
+            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+            train_config.quantization = "8bit"
+
+        if train_config.quantization == "8bit" and train_config.enable_fsdp:
+            raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization")
+
+        quant_config = QUANTIZATION_CONFIG()
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
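The quantization_config class itself is not shown in this diff. A minimal sketch of what it plausibly looks like, with field names inferred from the keyword arguments in the removed call site and from the new create_bnb_config(train_config.quantization) call; the defaults and exact behavior in llama_recipes/configs/quantization.py may differ:

from dataclasses import dataclass

import torch
from transformers import BitsAndBytesConfig

@dataclass
class quantization_config:
    # Field names taken from the removed call site above; the defaults here
    # are assumptions, not necessarily the library's actual values.
    quant_type: str = "fp4"                      # bitsandbytes 4-bit scheme: "fp4" or "nf4"
    compute_dtype: torch.dtype = torch.bfloat16  # dtype used for compute on dequantized weights
    use_double_quant: bool = False               # also quantize the quantization constants
    quant_storage: torch.dtype = torch.bfloat16  # storage dtype, relevant for FSDP wrapping

    def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig:
        # The new call site passes train_config.quantization, i.e. "4bit" or "8bit".
        if quantization not in ("4bit", "8bit"):
            raise ValueError(f"quantization must be '4bit' or '8bit', got: {quantization}")
        if quantization == "4bit":
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type=self.quant_type,
                bnb_4bit_compute_dtype=self.compute_dtype,
                bnb_4bit_use_double_quant=self.use_double_quant,
                bnb_4bit_quant_storage=self.quant_storage,
            )
        return BitsAndBytesConfig(load_in_8bit=True)

Because the new code routes quant_config through update_config(quant_config, **kwargs), the 4-bit fields should be overridable from the command line in the same way as the other config dataclasses, e.g. passing a quant_type override alongside --quantization 4bit (the exact flag spelling depends on update_config's matching rules).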