@@ -1,35 +1,37 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from pathlib import Path
-from datetime import datetime
-import torch
 import time
+from datetime import datetime
+from pathlib import Path
 
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    StateDictType,
-    FullStateDictConfig,  # general model non-sharded, non-flattened params
-    LocalStateDictConfig,  # flattened params, usable only by FSDP
-    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
-)
+import torch
+import torch.distributed as dist
 
-from torch.distributed._shard.checkpoint import (
+from torch.distributed.checkpoint import (
     FileSystemReader,
     FileSystemWriter,
-    save_state_dict,
     load_state_dict,
+    save_state_dict,
 )
 from torch.distributed.checkpoint.default_planner import (
-    DefaultSavePlanner,
     DefaultLoadPlanner,
+    DefaultSavePlanner,
 )
 
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    StateDictOptions,
+)
 
-from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
+from torch.distributed.fsdp import (
+    FullStateDictConfig,  # general model non-sharded, non-flattened params
+    FullyShardedDataParallel as FSDP,
+    LocalStateDictConfig,  # flattened params, usable only by FSDP
+    StateDictType,
+    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
+)
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed._shard.checkpoint as dist_cp
-import torch.distributed as dist
 
 
 def get_date_of_run():
@@ -62,7 +64,7 @@ def load_model_sharded(model, rank, cfg):
             print(f"No sharded_state_dict checkpoint directory found...skipping")
         return
     if rank == 0:
-         print(f"loading model from model path: {load_dir} ")
+        print(f"loading model from model path: {load_dir} ")
     reader = FileSystemReader(load_dir)
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
@@ -70,8 +72,8 @@ def load_model_sharded(model, rank, cfg):
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
-
-        dist_cp.load_state_dict(
+
+        load_state_dict(
             state_dict=checkpoint,
             storage_reader=reader,
         )
@@ -84,9 +86,9 @@ def load_model_sharded(model, rank, cfg):
         print(f"Sharded state checkpoint loaded from {load_dir}")
 
 
-def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
+def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
     """save model and optimizer via sharded_state_dict to save_dir"""
-
+
     folder_name = (
         cfg.dist_checkpoint_root_folder
         + "/"
@@ -99,30 +101,29 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
     if rank == 0:
         print(f"Saving model to {save_dir}")
 
-    distributed_writer = dist_cp.FileSystemWriter(
+    distributed_writer = FileSystemWriter(
         save_dir,
     )
     t0 = time.perf_counter()
 
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-
+
         state_dict = {"model": model.state_dict()}
         if optim is not None:
             state_dict["optim"] = FSDP.optim_state_dict(model, optim)
 
-        dist_cp.save_state_dict(
+        save_state_dict(
             state_dict=state_dict,
             storage_writer=distributed_writer,
             planner=DefaultSavePlanner(),
-
         )
     dist.barrier()
     t1 = time.perf_counter()
     if rank == 0:
         print(f"Sharded state checkpoint saved to {save_dir}")
-        print(
-            f"Checkpoint Time = {t1-t0:.4f}\n"
-        )
+        print(f"Checkpoint Time = {t1-t0:.4f}\n")
+
+
 def save_fsdp_model_checkpoint_full(
     model,
     optimizer,
@@ -138,29 +139,26 @@ def save_fsdp_model_checkpoint_full(
         cpu_state = model.state_dict()
 
         print(f"saving process: rank {rank} done w model state_dict\n")
-
 
     if rank == 0:
         print(f"--> saving model ...")
         # create save path
         folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt"
+        save_name = cfg.model_name.replace("/", "--") + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
 
         # save model
         torch.save(cpu_state, save_full_path)
 
-
         print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
-
 
 
 def load_model_checkpoint(model, rank, cfg):
@@ -181,42 +179,36 @@ def load_model_checkpoint(model, rank, cfg):
         )
         return
 
-
     model_checkpoint = torch.load(full_state_dict_model_path)
     # integrate into loaded model
     model.load_state_dict(model_checkpoint)
 
-
    print(f"model checkpoint loaded to rank0 cpu")
 
 
 def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
     """save optimizer state via full state dict"""
 
-
    print(f"--> optim state call on rank {rank}\n")
 
     # pull all sharded optimizer states to rank0 cpu...
 
    optim_state = FSDP.full_optim_state_dict(model, optimizer)
 
-
    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
 
     if rank == 0:
         folder_name = (
-        cfg.dist_checkpoint_root_folder
-        + "/"
-        + cfg.dist_checkpoint_folder
-        + "-"
-        + cfg.model_name
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
 
-        opt_save_name = (
-            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
-        )
+        opt_save_name = "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
         opt_save_full_path = save_dir / opt_save_name
 
         print(f"--> saving optimizer state...")
@@ -231,7 +223,6 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 
-
     if not optimizer_checkpoint_path.is_file():
         print(
             f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
@@ -248,43 +239,42 @@ def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
 
     print(f"optimizer shard loaded on rank {rank}")
 
-def load_sharded_model_single_gpu(model,model_path):
-
+
+def load_sharded_model_single_gpu(model, model_path):
+
     reader = FileSystemReader(model_path)
-
-    state_dict = {
-        "model": model.state_dict()
-    }
-
-    dist_cp.load_state_dict(
-                state_dict=state_dict,
-                storage_reader= FileSystemReader(model_path),
-                no_dist=True,
-            )
-
+
+    state_dict = {"model": model.state_dict()}
+
+    load_state_dict(
+        state_dict=state_dict,
+        storage_reader=FileSystemReader(model_path),
+        no_dist=True,
+    )
+
     model.load_state_dict(state_dict["model"])
-
+
     print(f"Sharded state checkpoint loaded from {model_path}")
     return model
 
+
 def save_peft_checkpoint(model, model_path):
     """save_pretrained peft model"""
 
     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-
+
     if isinstance(model, FSDP):
         state_dict = get_model_state_dict(model, options=options)
         model.save_pretrained(model_path, state_dict=state_dict)
     else:
         model.save_pretrained(model_path)
-
-
+
+
 def save_model_checkpoint(model, output_dir):
     """save model when not peft and on single device"""
-
+
     output_file = Path(output_dir) / "model.pt"
-
+
     state_dict = model.state_dict()
-
+
     torch.save(state_dict, output_file)
-
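
Note: besides import sorting and whitespace cleanup, the diff above migrates from the deprecated torch.distributed._shard.checkpoint module (previously also aliased as dist_cp) to torch.distributed.checkpoint, which exposes the same FileSystemReader/FileSystemWriter and load_state_dict/save_state_dict entry points. Below is a minimal single-process round-trip sketch of the new import path, assuming torch >= 2.0; the toy Linear model and the /tmp/ckpt_demo path are hypothetical, for illustration only:

    import torch
    from torch.distributed.checkpoint import (
        FileSystemReader,
        FileSystemWriter,
        load_state_dict,
        save_state_dict,
    )

    # Hypothetical stand-in for the real (possibly FSDP-wrapped) model.
    model = torch.nn.Linear(4, 4)

    # Save: stream the state dict through a filesystem writer.
    # no_dist=True allows running without an initialized process group.
    save_state_dict(
        state_dict={"model": model.state_dict()},
        storage_writer=FileSystemWriter("/tmp/ckpt_demo"),
        no_dist=True,
    )

    # Load: load_state_dict restores *in place* into a pre-populated
    # state dict, mirroring load_sharded_model_single_gpu above.
    state_dict = {"model": model.state_dict()}
    load_state_dict(
        state_dict=state_dict,
        storage_reader=FileSystemReader("/tmp/ckpt_demo"),
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])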