Fix deprecation warning in checkpointing util

Matthias Reso 6 months ago
parent
commit
24bd6de6fe

+ 61 - 71
src/llama_recipes/model_checkpointing/checkpoint_handler.py

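The commit swaps the deprecated `torch.distributed._shard.checkpoint` alias (previously imported as `dist_cp`) for the top-level `torch.distributed.checkpoint` package, which exposes the same `FileSystemReader`/`FileSystemWriter` and `save_state_dict`/`load_state_dict` names. Below is a condensed sketch (not part of the commit) of the sharded save/load round trip the updated file performs with the non-deprecated imports; it assumes PyTorch 2.x, an initialized process group, and an FSDP-wrapped `model`, and the directory name is illustrative.

import torch.distributed as dist
from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def save_sharded(model, save_dir="ckpt-sharded"):
    # Each rank serializes only its own shards into save_dir.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        save_state_dict(
            state_dict=state_dict,
            storage_writer=FileSystemWriter(save_dir),
            planner=DefaultSavePlanner(),
        )
    dist.barrier()


def load_sharded(model, load_dir="ckpt-sharded"):
    # Load shards in place, then push them back into the wrapped module.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        load_state_dict(state_dict=state_dict, storage_reader=FileSystemReader(load_dir))
        model.load_state_dict(state_dict["model"])
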
@@ -1,35 +1,37 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from pathlib import Path
-from datetime import datetime
-import torch
 import time
+from datetime import datetime
+from pathlib import Path
 
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
-    FullStateDictConfig,  # general model non-sharded, non-flattened params
-    LocalStateDictConfig,  # flattened params, usable only by FSDP
-    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
-)
+import torch
+import torch.distributed as dist
 
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
-    save_state_dict,
     load_state_dict,
+    save_state_dict,
 )
 from torch.distributed.checkpoint.default_planner import (
-    DefaultSavePlanner,
     DefaultLoadPlanner,
+    DefaultSavePlanner,
 )
 
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    StateDictOptions,
+)
 
-from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
+from torch.distributed.fsdp import (
+    FullStateDictConfig,  # general model non-sharded, non-flattened params
+    FullyShardedDataParallel as FSDP,
+    LocalStateDictConfig,  # flattened params, usable only by FSDP
+    StateDictType,
+    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
+)
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed._shard.checkpoint as dist_cp
-import torch.distributed as dist
 
 
 def get_date_of_run():
@@ -62,7 +64,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"No sharded_state_dict checkpoint directory found...skipping")
         return
     if rank == 0:
-         print(f"loading model from model path: {load_dir} ")
+        print(f"loading model from model path: {load_dir} ")
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
@@ -70,8 +72,8 @@ def load_model_sharded(model, rank, cfg):
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
-      
-        dist_cp.load_state_dict(
+
+        load_state_dict(
             state_dict=checkpoint,
             storage_reader=reader,
         )
@@ -84,9 +86,9 @@ def load_model_sharded(model, rank, cfg):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
+def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
-    
+
     folder_name = (
         cfg.dist_checkpoint_root_folder
         + "/"
@@ -99,30 +101,29 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     if rank == 0:
         print(f"Saving model to {save_dir}")
 
-    distributed_writer = dist_cp.FileSystemWriter(
+    distributed_writer = FileSystemWriter(
         save_dir,
     )
     t0 = time.perf_counter()
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        
+
         state_dict = {"model": model.state_dict()}
         if optim is not None:
             state_dict["optim"] = FSDP.optim_state_dict(model, optim)
 
-        dist_cp.save_state_dict(
+        save_state_dict(
             state_dict=state_dict,
             storage_writer=distributed_writer,
             planner=DefaultSavePlanner(),
-            
         )
     dist.barrier()
     t1 = time.perf_counter()
     if rank == 0:
         print(f"Sharded state checkpoint saved to {save_dir}")
-        print(
-            f"Checkpoint Time = {t1-t0:.4f}\n"
-        )
+        print(f"Checkpoint Time = {t1-t0:.4f}\n")
+
+
 def save_fsdp_model_checkpoint_full(
     model,
     optimizer,
@@ -138,29 +139,26 @@ def save_fsdp_model_checkpoint_full(
         cpu_state = model.state_dict()
 
         print(f"saving process: rank {rank}  done w model state_dict\n")
-   
 
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
         folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt"
+        save_name = cfg.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
 
         # save model
         torch.save(cpu_state, save_full_path)
 
-        
         print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
-      
 
 
 def load_model_checkpoint(model, rank, cfg):
@@ -181,42 +179,36 @@ def load_model_checkpoint(model, rank, cfg):
         )
         return
 
-
     model_checkpoint = torch.load(full_state_dict_model_path)
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-    
     print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
     """save optimizer state via full state dict"""
 
-   
     print(f"--> optim state call on rank {rank}\n")
 
     # pull all sharded optimizer states to rank0 cpu...
 
     optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-    
     print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
         folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
-        opt_save_name = (
-            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
-        )
+        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         opt_save_full_path = save_dir / opt_save_name
 
         print(f"--> saving optimizer state...")
@@ -231,7 +223,6 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-
     if not optimizer_checkpoint_path.is_file():
         print(
             f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
@@ -248,43 +239,42 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
 
     print(f"optimizer shard loaded on rank {rank}")
 
-def load_sharded_model_single_gpu(model,model_path):
-    
+
+def load_sharded_model_single_gpu(model, model_path):
+
     reader = FileSystemReader(model_path)
-    
-    state_dict = {
-        "model": model.state_dict()
-    }
-    
-    dist_cp.load_state_dict(
-                state_dict=state_dict,
-                storage_reader= FileSystemReader(model_path),
-                no_dist=True,
-            )
-    
+
+    state_dict = {"model": model.state_dict()}
+
+    load_state_dict(
+        state_dict=state_dict,
+        storage_reader=FileSystemReader(model_path),
+        no_dist=True,
+    )
+
     model.load_state_dict(state_dict["model"])
-    
+
     print(f"Sharded state checkpoint loaded from {model_path}")
     return model
 
+
 def save_peft_checkpoint(model, model_path):
     """save_pretrained peft model"""
 
     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-    
+
     if isinstance(model, FSDP):
         state_dict = get_model_state_dict(model, options=options)
         model.save_pretrained(model_path, state_dict=state_dict)
     else:
         model.save_pretrained(model_path)
-    
-    
+
+
 def save_model_checkpoint(model, output_dir):
     """save model when not peft and on single device"""
-    
+
     output_file = Path(output_dir) / "model.pt"
-    
+
     state_dict = model.state_dict()
-    
+
     torch.save(state_dict, output_file)
-
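
Two ways of materializing a full, CPU-resident state dict from an FSDP model appear in this file: the config-based `state_dict_type` context manager that the `FullStateDictConfig` import supports, and the newer options-based helper from `torch.distributed.checkpoint.state_dict` that `save_peft_checkpoint` now uses. The side-by-side sketch below is not part of the commit; it assumes PyTorch 2.2+ (for `get_model_state_dict`), an initialized process group, and an FSDP-wrapped `model`.

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def full_state_dict_config_based(model):
    # Older config-based pattern: the FullStateDictConfig imported in the
    # file drives an unsharded, CPU-offloaded state dict gathered on rank 0.
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config):
        return model.state_dict()


def full_state_dict_options_based(model):
    # Newer options-based pattern, as used by save_peft_checkpoint above,
    # suitable for handing the result to Hugging Face save_pretrained.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_model_state_dict(model, options=options)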