
Resume fine-tuning given the path of the previous PEFT checkpoint folder

Kai Wu 11 months ago
parent
commit
480c4f2b5e
2 changed files with 12 additions and 6 deletions
  1. src/llama_recipes/configs/training.py (+1 -0)
  2. src/llama_recipes/finetuning.py (+11 -6)

src/llama_recipes/configs/training.py (+1 -0)

@@ -31,6 +31,7 @@ class train_config:
     dataset = "samsum_dataset"
     peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1

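With the new `from_peft_checkpoint` field above, resuming can be requested through the keyword overrides that `main(**kwargs)` in finetuning.py accepts. A minimal sketch, assuming the usual llama_recipes keyword-override launch path; the model id and both paths are placeholders, not values from this commit:

    # Minimal sketch (not part of the commit): resume fine-tuning from an
    # earlier PEFT checkpoint via the new config field.
    from llama_recipes.finetuning import main

    main(
        model_name="meta-llama/Llama-2-7b-hf",                     # placeholder model id
        use_peft=True,
        peft_method="lora",
        from_peft_checkpoint="PATH/to/previous/PEFT/checkpoint",   # placeholder path
        output_dir="PATH/to/save/PEFT/model",                      # placeholder path
    )

Leaving `from_peft_checkpoint` empty keeps the previous behavior of building a fresh adapter from `generate_peft_config`.
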
src/llama_recipes/finetuning.py (+11 -6)

@@ -8,7 +8,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training
+from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
@@ -151,11 +151,16 @@ def main(**kwargs):
         model.to(torch.bfloat16)
 
     if train_config.use_peft:
-        peft_config = generate_peft_config(train_config, kwargs)
-        model = get_peft_model(model, peft_config)
+        # Load the pre-trained peft model checkpoint and setup its configuration
+        if train_config.from_peft_checkpoint:
+            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+        # Generate the peft config and start fine-tuning from original model
+        else:
+            peft_config = generate_peft_config(train_config, kwargs)
+            model = get_peft_model(model, peft_config)
+            if wandb_run:
+                wandb_run.config.update(peft_config)
         model.print_trainable_parameters()
-        if wandb_run:
-            wandb_run.config.update(peft_config)
 
 
     hsdp_device_mesh = None
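
For reference, the resume branch above follows the standard PEFT pattern of reloading adapter weights in trainable mode. A minimal standalone sketch, with a placeholder base model and checkpoint path (neither is part of the commit):

    # Standalone sketch of the resume pattern used in the diff above.
    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",           # placeholder base model
        torch_dtype=torch.bfloat16,
    )
    # is_trainable=True keeps the adapter weights trainable so optimization can
    # continue; the default (False) would load them for inference only.
    model = PeftModel.from_pretrained(
        base_model,
        "PATH/to/previous/PEFT/checkpoint",   # folder written by save_pretrained
        is_trainable=True,
    )
    model.print_trainable_parameters()

Note that `generate_peft_config` is skipped on the resume path, which is why the diff only updates the wandb config with `peft_config` when fine-tuning starts from the original model.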