train_utils.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from datetime import datetime
  9. import contextlib
  10. import torch
  11. import torch.cuda.nccl as nccl
  12. import torch.distributed as dist
  13. from torch.distributed.fsdp import StateDictType
  14. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  15. from tqdm import tqdm
  16. from transformers import LlamaTokenizer
  17. import json
  18. from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
  19. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  20. from llama_recipes.utils.memory_utils import MemoryTrace
  21. from accelerate.utils import is_xpu_available, is_ccl_available
  22. from llama_recipes.utils.flop_utils import FlopMeasure
  23. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  24. tokenizer.pad_token_id = 0
  25. tokenizer.padding_side = "left"
  26. @contextlib.contextmanager
  27. def profile(cfg, local_rank=None):
  28. use_profiler: bool = cfg.use_profiler
  29. use_flop_counter: bool = cfg.flop_counter
  30. if use_flop_counter and use_profiler:
  31. raise ValueError("Cannot use both profiler and flop counter")
  32. if use_profiler:
  33. # profiler needs a warmup stage to get the accurate profiling results
  34. wait_step, warmup_step, active_step = 1, 2, 3
  35. min_step = wait_step + warmup_step + active_step + 1
  36. if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
  37. raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
  38. print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
  39. with torch.profiler.profile(
  40. activities=[
  41. torch.profiler.ProfilerActivity.CPU,
  42. torch.profiler.ProfilerActivity.CUDA,
  43. ],
  44. schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
  45. on_trace_ready=torch.profiler.tensorboard_trace_handler(
  46. cfg.profiler_dir
  47. ),
  48. profile_memory=True,
  49. with_stack=False,
  50. with_flops=True,
  51. record_shapes=True,
  52. ) as torch_profiler:
  53. yield torch_profiler
  54. elif use_flop_counter:
  55. if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
  56. raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
  57. with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
  58. yield flop_counter
  59. else:
  60. torch_profiler = contextlib.nullcontext()
  61. yield None
  62. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
  63. """
  64. Trains the model on the given dataloader
  65. Args:
  66. model: The model to be trained
  67. train_dataloader: The dataloader containing the training data
  68. optimizer: The optimizer used for training
  69. lr_scheduler: The learning rate scheduler
  70. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  71. num_epochs: The number of epochs to train for
  72. local_rank: The rank of the current node in a distributed setting
  73. train_config: The training configuration
  74. eval_dataloader: The dataloader containing the eval data
  75. tokenizer: tokenizer used in the eval for decoding the predicitons
  76. Returns: results dictionary containing average training and validation perplexity and loss
  77. """
  78. # Create a gradient scaler for fp16
  79. if train_config.use_fp16 and train_config.enable_fsdp:
  80. scaler = ShardedGradScaler()
  81. elif train_config.use_fp16 and not train_config.enable_fsdp:
  82. scaler = torch.cuda.amp.GradScaler()
  83. if train_config.enable_fsdp:
  84. world_size = int(os.environ["WORLD_SIZE"])
  85. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  86. train_prep = []
  87. train_loss = []
  88. val_prep = []
  89. val_loss =[]
  90. if train_config.save_metrics:
  91. if not os.path.exists(train_config.output_dir):
  92. os.makedirs(train_config.output_dir, exist_ok=True)
  93. metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
  94. train_step_perplexity = []
  95. train_step_loss = []
  96. val_step_loss = []
  97. val_step_perplexity = []
  98. epoch_times = []
  99. checkpoint_times = []
  100. results = {}
  101. best_val_loss = float("inf")
  102. total_train_steps = 0
  103. max_steps_reached = False # Flag to indicate max training steps reached
  104. # Start the training loop
  105. for epoch in range(train_config.num_epochs):
  106. print(f"Starting epoch {epoch}/{train_config.num_epochs}")
  107. print(f"train_config.max_train_step: {train_config.max_train_step}")
  108. # stop when the maximum number of training steps is reached
  109. if max_steps_reached:
  110. break
  111. epoch_start_time = time.perf_counter()
  112. with MemoryTrace() as memtrace: # track the memory usage
  113. model.train()
  114. total_loss = 0.0
  115. total_length = len(train_dataloader)//gradient_accumulation_steps
  116. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  117. with profile(train_config,local_rank) as profile_context:
  118. for step, batch in enumerate(train_dataloader):
  119. total_train_steps += 1
  120. # stop when the maximum number of training steps is reached
  121. if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
  122. max_steps_reached = True
  123. if not train_config.enable_fsdp or local_rank==0:
  124. print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
  125. break
  126. for key in batch.keys():
  127. if train_config.enable_fsdp:
  128. if is_xpu_available():
  129. batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
  130. else:
  131. batch[key] = batch[key].to(local_rank)
  132. else:
  133. if is_xpu_available():
  134. batch[key] = batch[key].to('xpu:0')
  135. elif torch.cuda.is_available():
  136. batch[key] = batch[key].to('cuda:0')
  137. with autocast():
  138. loss = model(**batch).loss
  139. loss = loss / gradient_accumulation_steps
  140. if train_config.save_metrics:
  141. train_step_loss.append(loss.detach().float().item())
  142. train_step_perplexity.append(float(torch.exp(loss.detach().float())))
  143. total_loss += loss.detach().float()
  144. if train_config.use_fp16:
  145. # if fp16 is enabled, use gradient scaler to handle gradient update
  146. scaler.scale(loss).backward()
  147. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  148. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  149. scaler.unscale_(optimizer)
  150. if train_config.enable_fsdp:
  151. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  152. else:
  153. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  154. scaler.step(optimizer)
  155. scaler.update()
  156. optimizer.zero_grad()
  157. pbar.update(1)
  158. else:
  159. # regular backpropagation when fp16 is not used
  160. loss.backward()
  161. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  162. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  163. if train_config.enable_fsdp:
  164. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  165. else:
  166. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  167. optimizer.step()
  168. optimizer.zero_grad()
  169. pbar.update(1)
  170. if train_config.use_profiler or train_config.flop_counter:
  171. profile_context.step()
  172. if train_config.flop_counter and profile_context.is_done():
  173. TFlops = profile_context.get_flops_per_sec() / 1e12
  174. if wandb_run:
  175. if not train_config.enable_fsdp or rank==0:
  176. wandb_run.log({
  177. 'train/epoch': epoch + 1,
  178. 'train/step': epoch * len(train_dataloader) + step,
  179. 'train/loss': loss.detach().float(),
  180. })
  181. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  182. if train_config.save_metrics:
  183. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  184. pbar.close()
  185. epoch_end_time = time.perf_counter()-epoch_start_time
  186. epoch_times.append(epoch_end_time)
  187. # Reducing total_loss across all devices if there's more than one CUDA device
  188. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  189. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  190. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  191. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  192. train_epoch_loss = total_loss / len(train_dataloader)
  193. if train_config.enable_fsdp:
  194. train_epoch_loss = train_epoch_loss/world_size
  195. train_perplexity = torch.exp(train_epoch_loss)
  196. train_prep.append(float(train_perplexity))
  197. train_loss.append(float(train_epoch_loss))
  198. if not train_config.enable_fsdp or rank==0:
  199. memtrace.print_stats()
  200. # Update the learning rate as needed
  201. lr_scheduler.step()
  202. if train_config.run_validation:
  203. eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
  204. if train_config.save_metrics:
  205. val_step_loss.extend(temp_val_loss)
  206. val_step_perplexity.extend(temp_step_perplexity)
  207. checkpoint_start_time = time.perf_counter()
  208. if train_config.save_model and eval_epoch_loss < best_val_loss:
  209. if train_config.enable_fsdp:
  210. dist.barrier()
  211. if train_config.use_peft:
  212. if train_config.enable_fsdp:
  213. if rank==0:
  214. print(f"we are about to save the PEFT modules")
  215. else:
  216. print(f"we are about to save the PEFT modules")
  217. save_peft_checkpoint(model, train_config.output_dir)
  218. if train_config.enable_fsdp:
  219. if rank==0:
  220. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  221. else:
  222. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  223. else:
  224. if not train_config.enable_fsdp:
  225. save_model_checkpoint(model, train_config.output_dir)
  226. elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  227. print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
  228. print("=====================================================")
  229. save_fsdp_model_checkpoint_full(
  230. model, optimizer, rank, train_config, epoch=epoch
  231. )
  232. if train_config.save_optimizer:
  233. print(" Saving the FSDP optimizer using FULL_STATE_DICT")
  234. print("=====================================================")
  235. save_optimizer_checkpoint(
  236. model, optimizer, rank, train_config, epoch=epoch
  237. )
  238. elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  239. if train_config.save_optimizer:
  240. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  241. print("=====================================================")
  242. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  243. else:
  244. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  245. print("=====================================================")
  246. save_model_and_optimizer_sharded(model, rank, train_config)
  247. if train_config.enable_fsdp:
  248. dist.barrier()
  249. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  250. checkpoint_times.append(checkpoint_end_time)
  251. if eval_epoch_loss < best_val_loss:
  252. best_val_loss = eval_epoch_loss
  253. if train_config.enable_fsdp:
  254. if rank==0:
  255. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  256. else:
  257. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  258. val_loss.append(float(best_val_loss))
  259. val_prep.append(float(eval_ppl))
  260. if train_config.enable_fsdp:
  261. if rank==0:
  262. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  263. else:
  264. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  265. # Saving the results every epoch to plot later
  266. if train_config.save_metrics:
  267. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  268. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  269. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  270. avg_train_prep = sum(train_prep)/len(train_prep)
  271. avg_train_loss = sum(train_loss)/len(train_loss)
  272. if train_config.run_validation:
  273. avg_eval_prep = sum(val_prep)/len(val_prep)
  274. avg_eval_loss = sum(val_loss)/len(val_loss)
  275. results['avg_train_prep'] = avg_train_prep
  276. results['avg_train_loss'] = avg_train_loss
  277. if train_config.run_validation:
  278. results['avg_eval_prep'] = avg_eval_prep
  279. results['avg_eval_loss'] = avg_eval_loss
  280. results["avg_epoch_time"] = avg_epoch_time
  281. results["avg_checkpoint_time"] = avg_checkpoint_time
  282. if train_config.save_metrics:
  283. results["metrics_filename"] = metrics_filename
  284. if train_config.flop_counter:
  285. results["model_tflops"]= TFlops
  286. #saving the training params including fsdp setting for reference.
  287. if train_config.enable_fsdp and not train_config.use_peft and rank==0:
  288. save_train_params(train_config, fsdp_config, rank)
  289. return results
  290. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
  291. """
  292. Evaluates the model on the given dataloader
  293. Args:
  294. model: The model to evaluate
  295. eval_dataloader: The dataloader containing the evaluation data
  296. local_rank: The rank of the current node in a distributed setting
  297. tokenizer: The tokenizer used to decode predictions
  298. Returns: eval_ppl, eval_epoch_loss
  299. """
  300. if train_config.enable_fsdp:
  301. world_size = int(os.environ["WORLD_SIZE"])
  302. model.eval()
  303. eval_preds = []
  304. val_step_loss = []
  305. val_step_perplexity = []
  306. eval_loss = 0.0 # Initialize evaluation loss
  307. total_eval_steps = 0
  308. with MemoryTrace() as memtrace:
  309. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  310. total_eval_steps += 1
  311. # stop when the maximum number of eval steps is reached
  312. if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
  313. if not train_config.enable_fsdp or local_rank==0:
  314. print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
  315. break
  316. for key in batch.keys():
  317. if train_config.enable_fsdp:
  318. batch[key] = batch[key].to(local_rank)
  319. else:
  320. if is_xpu_available():
  321. batch[key] = batch[key].to('xpu:0')
  322. else:
  323. batch[key] = batch[key].to('cuda:0')
  324. # Ensure no gradients are computed for this scope to save memory
  325. with torch.no_grad():
  326. # Forward pass and compute loss
  327. outputs = model(**batch)
  328. loss = outputs.loss
  329. if train_config.save_metrics:
  330. val_step_loss.append(loss.detach().float().item())
  331. val_step_perplexity.append(float(torch.exp(loss.detach().float())))
  332. eval_loss += loss.detach().float()
  333. # Decode predictions and add to evaluation predictions list
  334. preds = torch.argmax(outputs.logits, -1)
  335. eval_preds.extend(
  336. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  337. )
  338. # If there's more than one CUDA device, reduce evaluation loss across all devices
  339. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  340. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  341. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  342. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  343. # Compute average loss and perplexity
  344. eval_epoch_loss = eval_loss / len(eval_dataloader)
  345. if train_config.enable_fsdp:
  346. eval_epoch_loss = eval_epoch_loss/world_size
  347. eval_ppl = torch.exp(eval_epoch_loss)
  348. # Print evaluation metrics
  349. if train_config.enable_fsdp:
  350. if local_rank==0:
  351. print(f" {eval_ppl=} {eval_epoch_loss=}")
  352. else:
  353. print(f" {eval_ppl=} {eval_epoch_loss=}")
  354. if wandb_run:
  355. wandb_run.log({
  356. 'eval/perplexity': eval_ppl,
  357. 'eval/loss': eval_epoch_loss,
  358. }, commit=False)
  359. return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
  360. def freeze_transformer_layers(model, num_layer):
  361. for i, layer in enumerate(model.model.layers):
  362. if i < num_layer:
  363. for param in layer.parameters():
  364. param.requires_grad = False
  365. def check_frozen_layers_peft_model(model):
  366. for i, layer in enumerate(model.base_model.model.model.layers):
  367. for name, param in layer.named_parameters():
  368. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  369. def setup():
  370. """Initialize the process group for distributed training"""
  371. if is_ccl_available():
  372. # distributed training on xpus
  373. dist.init_process_group("ccl")
  374. else:
  375. dist.init_process_group("nccl")
  376. def setup_environ_flags(rank):
  377. """Set environment flags for debugging purposes"""
  378. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  379. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  380. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  381. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  382. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  383. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  384. if rank == 0:
  385. print(f"--> Running with torch dist debug set to detail")
  386. def cleanup():
  387. """Clean up the process group after training"""
  388. dist.destroy_process_group()
  389. def clear_gpu_cache(rank=None):
  390. """Clear the GPU cache for all ranks"""
  391. if rank == 0:
  392. print(f"Clearing GPU cache for all ranks")
  393. if is_xpu_available():
  394. torch.xpu_empty_cache()
  395. else:
  396. torch.cuda.empty_cache()
  397. def get_parameter_dtypes(model):
  398. """Get the data types of model parameters"""
  399. parameter_dtypes = {}
  400. for name, parameter in model.named_parameters():
  401. parameter_dtypes[name] = parameter.dtype
  402. return parameter_dtypes
  403. def print_model_size(model, config, rank: int = 0) -> None:
  404. """
  405. Print model name, the number of trainable parameters and initialization time.
  406. Args:
  407. model: The PyTorch model.
  408. model_name (str): Name of the model.
  409. init_time_start (float): Initialization start time.
  410. init_time_end (float): Initialization end time.
  411. rank (int, optional): Current process's rank. Defaults to 0.
  412. """
  413. if rank == 0:
  414. print(f"--> Model {config.model_name}")
  415. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  416. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  417. def get_policies(cfg, rank):
  418. """Get the policies for mixed precision and fsdp wrapping"""
  419. verify_bfloat_support = ((
  420. torch.version.cuda
  421. and torch.cuda.is_bf16_supported()
  422. and torch.version.cuda >= "11.0"
  423. and dist.is_nccl_available()
  424. and nccl.version() >= (2, 10)
  425. ) or
  426. (is_xpu_available()))
  427. mixed_precision_policy = None
  428. wrapping_policy = None
  429. # Mixed precision
  430. if cfg.mixed_precision:
  431. bf16_ready = verify_bfloat_support
  432. if bf16_ready and not cfg.use_fp16:
  433. mixed_precision_policy = bfSixteen
  434. if rank == 0:
  435. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  436. elif cfg.use_fp16:
  437. mixed_precision_policy = fpSixteen
  438. if rank == 0:
  439. print(f"FP16 enabled")
  440. else:
  441. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  442. wrapping_policy = get_llama_wrapper()
  443. return mixed_precision_policy, wrapping_policy
  444. def save_train_params(train_config, fsdp_config, rank):
  445. """
  446. This function saves the train_config and FSDP config into a train_params.yaml.
  447. This will be used by converter script in the inference folder to fetch the HF model name or path.
  448. It also would be hepful as a log for future references.
  449. """
  450. # Convert the train_config and fsdp_config objects to dictionaries,
  451. # converting all values to strings to ensure they can be serialized into a YAML file
  452. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  453. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  454. # Merge the two dictionaries into one
  455. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  456. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  457. folder_name = (
  458. train_config.dist_checkpoint_root_folder
  459. + "/"
  460. + train_config.dist_checkpoint_folder
  461. + "-"
  462. + train_config.model_name
  463. )
  464. save_dir = Path.cwd() / folder_name
  465. # If the directory does not exist, create it
  466. if not os.path.exists(save_dir):
  467. os.makedirs(save_dir)
  468. # Convert the dictionary to a YAML string
  469. config_yaml = yaml.dump(train_params_dict, indent=4)
  470. file_name = os.path.join(save_dir,'train_params.yaml')
  471. # Check if there's a directory with the same name as the file
  472. if os.path.isdir(file_name):
  473. print(f"Error: {file_name} is a directory, not a file.")
  474. else:
  475. # Write the YAML string to the file
  476. with open(file_name, 'w') as f:
  477. f.write(config_yaml)
  478. if rank==0:
  479. print(f"training params are saved in {file_name}")
  480. def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
  481. metrics_data = {
  482. "train_step_loss": train_step_loss,
  483. "train_epoch_loss": train_epoch_loss,
  484. "train_step_perplexity": train_step_ppl,
  485. "train_epoch_perplexity": train_epoch_ppl,
  486. "val_step_loss": val_step_loss,
  487. "val_epoch_loss": val_epoch_loss,
  488. "val_step_perplexity": val_step_ppl,
  489. "val_epoch_perplexity": val_epoch_ppl
  490. }
  491. with open(output_filename, "w") as f:
  492. json.dump(metrics_data, f)