
Resume the fine-tuning process from the previous PEFT checkpoint folder (#531)

Kai Wu 11 months ago
parent
commit
44b66374be

+ 3 - 2
docs/multi_gpu.md

@@ -34,7 +34,7 @@ The args used in the command above are:
 
 * `--use_peft` boolean flag to enable PEFT methods in the script
 
-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method; here we use `lora`, the other option is `llama_adapter`.
 
 We use `torchrun` here to spawn multiple processes for FSDP.
 
@@ -138,8 +138,9 @@ It lets us specify the training settings for everything from `model_name` to `da
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, load this PEFT checkpoint and resume fine-tuning from it
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1

+ 4 - 2
docs/single_gpu.md

@@ -27,7 +27,7 @@ The args used in the command above are:
 
 * `--use_peft` boolean flag to enable PEFT methods in the script
 
-* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
+* `--peft_method` to specify the PEFT method; here we use `lora`, the other option is `llama_adapter`.
 
 * `--quantization` boolean flag to enable int8 quantization
 
@@ -94,8 +94,9 @@ It let us specify the training settings, everything from `model_name` to `datase
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, load this PEFT checkpoint and resume fine-tuning from it
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
@@ -112,6 +113,7 @@ It let us specify the training settings, everything from `model_name` to `datase
     flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
     use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
     profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
+
 ```
 
 * [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

+ 3 - 1
recipes/finetuning/README.md

@@ -48,8 +48,9 @@ It lets us specify the training settings for everything from `model_name` to `da
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, load this PEFT checkpoint and resume fine-tuning from it
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
@@ -66,6 +67,7 @@ It lets us specify the training settings for everything from `model_name` to `da
     flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
     use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
     profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
+
 ```
 
 * [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -31,6 +31,7 @@ class train_config:
     dataset = "samsum_dataset"
     peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, load this PEFT checkpoint and resume fine-tuning from it
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
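
For context, a minimal sketch of how the new field would be exercised through the recipe's `main(**kwargs)` entry point. This is a hypothetical illustration: the model name and checkpoint path are placeholders, not values from this commit.

```python
# Hypothetical usage sketch: resume a LoRA fine-tuning run from a previously
# saved PEFT checkpoint folder via the new from_peft_checkpoint field.
from llama_recipes.finetuning import main

main(
    model_name="meta-llama/Llama-2-7b-hf",           # illustrative base model
    use_peft=True,
    peft_method="lora",
    from_peft_checkpoint="PATH/to/save/PEFT/model",  # folder written by a prior run
    output_dir="PATH/to/save/PEFT/model",
)
```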

+ 11 - 5
src/llama_recipes/finetuning.py

@@ -8,7 +8,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training
+from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -134,7 +134,7 @@ def main(**kwargs):
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
@@ -151,11 +151,17 @@ def main(**kwargs):
         model.to(torch.bfloat16)
 
     if train_config.use_peft:
-        peft_config = generate_peft_config(train_config, kwargs)
-        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        # Load the pre-trained PEFT model checkpoint and set up its configuration
+        if train_config.from_peft_checkpoint:
+            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+            peft_config = model.peft_config  # attribute, not a method: dict of adapter name -> config
+        # Generate the PEFT config and start fine-tuning from the original model
+        else:
+            peft_config = generate_peft_config(train_config, kwargs)
+            model = get_peft_model(model, peft_config)
         if wandb_run:
             wandb_run.config.update(peft_config)
+        model.print_trainable_parameters()
 
 
     hsdp_device_mesh = None
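
Outside the recipe, the resume mechanism boils down to PEFT's `PeftModel.from_pretrained` with `is_trainable=True`; without that flag the adapters load frozen for inference. A standalone sketch of the save-then-resume round trip, with an illustrative model name and placeholder paths:

```python
# Standalone sketch of save-then-resume for PEFT adapters (model name and
# paths are illustrative).
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

# First run: attach LoRA adapters to the base model and save them.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
model.save_pretrained("PATH/to/save/PEFT/model")  # writes only the adapter weights

# Later run: reload the adapters onto a fresh base model, keeping them trainable.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base, "PATH/to/save/PEFT/model", is_trainable=True)
peft_config = model.peft_config  # dict mapping adapter name -> config
```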