@@ -116,7 +116,7 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank == 0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
-    # setting quantization configs
+    # setting quantization configs
     bnb_config = None
     if train_config.quantization:
         if type(train_config.quantization) == type(True):
@@ -135,6 +135,17 @@ def main(**kwargs):
         update_config(quant_config, **kwargs)
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
+
+        if train_config.enable_fsdp:
+            if train_config.quantization == "4bit":
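+                # FSDP flattens weights into shared buffers, so the 4-bit
+                # quantized parameters must use the same storage dtype as the
+                # rest of the model to be flattened and sharded alongside them.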
+                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype
+                from logging import getLogger
+                logger = getLogger()
+                logger.warning(
+                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "
+                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."
+                )
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     config = AutoConfig.from_pretrained(train_config.model_name)
@@ -264,10 +275,9 @@ def main(**kwargs):
     elif torch.cuda.is_available():
         device_id = torch.cuda.current_device()
 
-    if train_config.freeze_LLM_only:
-        use_orig_params = True
-    else:
-        use_orig_params = False
+
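+    # FSDP only supports flat parameters that mix frozen and trainable
+    # weights when `use_orig_params=True`; that is the case both with
+    # freeze_LLM_only and with PEFT adapters on a frozen base model.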
+    use_orig_params = train_config.freeze_LLM_only or train_config.use_peft
+
     model = FSDP(
         model,
         auto_wrap_policy=(