瀏覽代碼

Use new get_model_state_dict api for save_pretrained peft model (#629)

Matthias Reso 8 月之前
父節點
當前提交
eca526526c

+ 1 - 0
src/llama_recipes/model_checkpointing/__init__.py

@@ -4,6 +4,7 @@
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
+    save_peft_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

+ 10 - 1
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -26,6 +26,7 @@ from torch.distributed.checkpoint.default_planner import (
 )
 
 
+from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 import torch.distributed._shard.checkpoint as dist_cp
 import torch.distributed as dist
@@ -264,4 +265,12 @@ def load_sharded_model_single_gpu(model,model_path):
     model.load_state_dict(state_dict["model"])
     
     print(f"Sharded state checkpoint loaded from {model_path}")
-    return model
+    return model
+
+def save_peft_checkpoint(model, model_path):
+    """Save a PEFT model with ``save_pretrained`` using a gathered state dict.
+
+    Args:
+        model: PEFT model (presumably FSDP-wrapped — the enclosing module
+            handles distributed checkpoints; verify against callers) exposing
+            ``save_pretrained``.
+        model_path: Output directory passed to ``save_pretrained``.
+    """
+
+    # full_state_dict=True gathers the complete (unsharded) state dict;
+    # cpu_offload=True moves the gathered tensors to CPU before saving.
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+    # Pass the materialized state dict explicitly so save_pretrained writes
+    # these gathered weights rather than querying the (sharded) model itself.
+    state_dict = get_model_state_dict(model, options=options)
+    model.save_pretrained(model_path, state_dict=state_dict)

+ 2 - 2
src/llama_recipes/utils/train_utils.py

@@ -20,7 +20,7 @@ from transformers import LlamaTokenizer
 import json
 
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
@@ -235,7 +235,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)
+                    save_peft_checkpoint(model, train_config.output_dir)
                     if train_config.enable_fsdp:
                         if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")