@@ -8,7 +8,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training
+from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id

-    # If there is a mismatch between tokenizer vocab size and embedding matrix,
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
@@ -151,11 +151,16 @@ def main(**kwargs):
         model.to(torch.bfloat16)

     if train_config.use_peft:
-        peft_config = generate_peft_config(train_config, kwargs)
-        model = get_peft_model(model, peft_config)
+        # Load the pre-trained peft model checkpoint and setup its configuration
+        if train_config.from_peft_checkpoint:
+            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+        # Generate the peft config and start fine-tuning from original model
+        else:
+            peft_config = generate_peft_config(train_config, kwargs)
+            model = get_peft_model(model, peft_config)
+        if wandb_run:
+            wandb_run.config.update(peft_config)
         model.print_trainable_parameters()
-        if wandb_run:
-            wandb_run.config.update(peft_config)


     hsdp_device_mesh = None
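
For reference, the resume path added above relies on PeftModel.from_pretrained with is_trainable=True, which re-attaches the saved adapter weights to the freshly loaded base model and leaves them unfrozen so training continues from where the checkpoint left off. The sketch below shows that flow in isolation, outside the training script: the base model name and checkpoint directory are placeholders, and reading the restored adapter settings back from the model's peft_config dict is one way to keep the wandb config logging populated when starting from a checkpoint rather than from a freshly generated peft_config.

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder names: substitute the real base model and the directory the
# adapter weights were saved to (e.g. via save_pretrained during training).
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
base.to(torch.bfloat16)

# is_trainable=True keeps the adapter parameters unfrozen so a subsequent
# training run keeps updating them; the default loads them frozen for inference.
model = PeftModel.from_pretrained(base, "outputs/peft_checkpoint", is_trainable=True)

# The adapter configuration is restored from the checkpoint; peft_config is a
# dict keyed by adapter name ("default" unless one was set explicitly), so its
# contents can still be recorded in a wandb run config when resuming.
peft_config = model.peft_config["default"]

model.print_trainable_parameters()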