|
@@ -154,12 +154,13 @@ def main(**kwargs):
|
|
|
# Load the pre-trained peft model checkpoint and setup its configuration
|
|
|
if train_config.from_peft_checkpoint:
|
|
|
model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
|
|
|
+ peft_config = model.peft_config()
|
|
|
# Generate the peft config and start fine-tuning from original model
|
|
|
else:
|
|
|
peft_config = generate_peft_config(train_config, kwargs)
|
|
|
model = get_peft_model(model, peft_config)
|
|
|
- if wandb_run:
|
|
|
- wandb_run.config.update(peft_config)
|
|
|
+ if wandb_run:
|
|
|
+ wandb_run.config.update(peft_config)
|
|
|
model.print_trainable_parameters()
|
|
|
|
|
|
|