checkpoint_handler.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from pathlib import Path
from datetime import datetime
import torch
import time

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,  # general model non-sharded, non-flattened params
    LocalStateDictConfig,  # flattened params, usable only by FSDP
    # ShardedStateDictConfig, # un-flattened params but sharded, usable by other parallel schemes.
)

from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)

from torch.distributed.checkpoint.default_planner import (
    DefaultSavePlanner,
    DefaultLoadPlanner,
)

from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed as dist

def get_date_of_run():
    """create date and time string for file save uniqueness
    example: 2022-05-07-08:31:12_PM
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run


# create singleton saving policy to avoid making it over and over
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

def load_model_sharded(model, rank, cfg):
    """load a sharded_state_dict checkpoint into an FSDP-wrapped model"""
    # torch.manual_seed(103)
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    load_dir = Path.cwd() / folder_name

    if not load_dir.exists():
        if rank == 0:
            print("No sharded_state_dict checkpoint directory found...skipping")
        return
    if rank == 0:
        print(f"loading model from model path: {load_dir} ")
    reader = FileSystemReader(load_dir)

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = {"model": model.state_dict()}
        if rank == 0:
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")

        # read the sharded checkpoint into the state dict in place
        dist_cp.load_state_dict(
            state_dict=checkpoint,
            storage_reader=reader,
        )
        if rank == 0:
            print("checkpoint after load_state_dict()")
            ck = checkpoint.keys()
            print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
        model.load_state_dict(checkpoint["model"])
    if rank == 0:
        print(f"Sharded state checkpoint loaded from {load_dir}")

def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
    """save model and optimizer via sharded_state_dict to save_dir"""
    folder_name = (
        cfg.dist_checkpoint_root_folder
        + "/"
        + cfg.dist_checkpoint_folder
        + "-"
        + cfg.model_name
    )

    save_dir = Path.cwd() / folder_name
    if rank == 0:
        print(f"Saving model to {save_dir}")

    distributed_writer = dist_cp.FileSystemWriter(
        save_dir,
    )
    t0 = time.perf_counter()

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
        if optim is not None:
            state_dict["optim"] = FSDP.optim_state_dict(model, optim)

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=distributed_writer,
            planner=DefaultSavePlanner(),
        )
    dist.barrier()
    t1 = time.perf_counter()
    if rank == 0:
        print(f"Sharded state checkpoint saved to {save_dir}")
        print(f"Checkpoint Time = {t1-t0:.4f}\n")
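
# Illustrative usage (not part of the original file): a minimal sketch of how the two
# sharded-checkpoint helpers above could be paired in an FSDP training loop. The
# `train_config` object is assumed to expose the same attributes the helpers read
# (dist_checkpoint_root_folder, dist_checkpoint_folder, model_name); `num_epochs`
# and the body of the training step are placeholders.
def _example_sharded_checkpoint_loop(model, optimizer, rank, train_config, num_epochs=1):
    # resume model weights from an existing sharded checkpoint directory, if present
    load_model_sharded(model, rank, train_config)
    for epoch in range(num_epochs):
        # ... one epoch of training steps would run here ...
        # collectively write model (and optimizer) shards from every rank
        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)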

def save_model_checkpoint(
    model,
    optimizer,
    rank,
    cfg,
    epoch=1,
):
    """saving model via rank0 cpu streaming and full_state_dict"""

    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
    ):
        cpu_state = model.state_dict()

        print(f"saving process: rank {rank} done w model state_dict\n")

    if rank == 0:
        print("--> saving model ...")
        # create save path
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
        save_full_path = str(save_dir) + "/" + save_name

        # save model
        torch.save(cpu_state, save_full_path)

        print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

def load_model_checkpoint(model, rank, cfg):
    """load local checkpoint to rank0 cpu
    must be called * before * passing to FSDP"""

    if rank != 0:
        return

    # where is the checkpoint at...
    full_state_dict_model_path = (
        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
    )
    # is it present...
    if not full_state_dict_model_path.is_file():
        print(
            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
        )
        return

    model_checkpoint = torch.load(full_state_dict_model_path)
    # integrate into loaded model
    model.load_state_dict(model_checkpoint)

    print("model checkpoint loaded to rank0 cpu")

def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
    """save optimizer state via full state dict"""

    print(f"--> optim state call on rank {rank}\n")

    # pull all sharded optimizer states to rank0 cpu...
    optim_state = FSDP.full_optim_state_dict(model, optimizer)

    print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

    if rank == 0:
        folder_name = (
            cfg.dist_checkpoint_root_folder
            + "/"
            + cfg.dist_checkpoint_folder
            + "-"
            + cfg.model_name
        )
        save_dir = Path.cwd() / folder_name
        save_dir.mkdir(parents=True, exist_ok=True)

        opt_save_name = (
            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
        )
        opt_save_full_path = save_dir / opt_save_name

        print("--> saving optimizer state...")
        torch.save(optim_state, opt_save_full_path)
        print(f"--> saved {opt_save_full_path} to disk")

def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
    """load an fsdp optimizer full_state checkpoint using scatter method
    this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
    """

    if not optimizer_checkpoint_path.is_file():
        print(
            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
        )
        return

    full_osd = None

    if rank == 0:
        full_osd = torch.load(optimizer_checkpoint_path)

    # called from all ranks, though only rank0 has a valid param for full_osd
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

    print(f"optimizer shard loaded on rank {rank}")
    # hand the re-sharded state dict back so the caller can apply it with
    # optimizer.load_state_dict(sharded_osd)
    return sharded_osd

def load_sharded_model_single_gpu(model, model_path):
    """load a sharded_state_dict checkpoint into a plain (non-FSDP) model on a single device"""
    reader = FileSystemReader(model_path)

    state_dict = {
        "model": model.state_dict()
    }

    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=reader,
        no_dist=True,
    )

    model.load_state_dict(state_dict["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    return model
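
# Illustrative usage (not part of the original file): a rough sketch of consolidating a
# sharded FSDP checkpoint into a standard Hugging Face checkpoint for inference. The
# model name and paths are placeholders, and the transformers dependency is assumed
# to be installed.
def _example_consolidate_sharded_checkpoint(hf_model_name, sharded_ckpt_dir, output_dir):
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(hf_model_name)    # plain cpu model, not FSDP-wrapped
    model = load_sharded_model_single_gpu(model, sharded_ckpt_dir)
    model.save_pretrained(output_dir)                              # write a consolidated checkpoint
    return model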

def save_peft_checkpoint(model, model_path):
    """save_pretrained peft model"""
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    state_dict = get_model_state_dict(model, options=options)
    model.save_pretrained(model_path, state_dict=state_dict)
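
# Illustrative usage (not part of the original file): a minimal sketch of saving a trained
# PEFT adapter with the helper above and re-attaching it to a base model later; the
# adapter directory and the peft dependency are assumptions for this example.
def _example_save_and_reload_adapter(peft_model, base_model, adapter_dir):
    from peft import PeftModel
    save_peft_checkpoint(peft_model, adapter_dir)                  # peft's save_pretrained keeps only adapter weights
    return PeftModel.from_pretrained(base_model, adapter_dir)      # reload the adapter onto a base model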