train_utils.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from pkg_resources import packaging
  9. from datetime import datetime
  10. import contextlib
  11. import torch
  12. import torch.cuda.nccl as nccl
  13. import torch.distributed as dist
  14. from torch.distributed.fsdp import StateDictType
  15. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  16. from tqdm import tqdm
  17. from transformers import LlamaTokenizer
  18. import json
  19. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  20. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  21. from llama_recipes.utils.memory_utils import MemoryTrace
  22. from accelerate.utils import is_xpu_available, is_ccl_available
  23. from llama_recipes.utils.flop_utils import FlopMeasure
  24. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  25. tokenizer.pad_token_id = 0
  26. tokenizer.padding_side = "left"
  27. @contextlib.contextmanager
  28. def profile(cfg, local_rank=None):
  29. use_profiler: bool = cfg.use_profiler
  30. use_flop_counter: bool = cfg.flop_counter
  31. if use_flop_counter and use_profiler:
  32. raise ValueError("Cannot use both profiler and flop counter")
  33. if use_profiler:
  34. # profiler needs a warmup stage to get the accurate profiling results
  35. wait_step, warmup_step, active_step = 1, 2, 3
  36. min_step = wait_step + warmup_step + active_step + 1
  37. if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
  38. raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
  39. print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
  40. with torch.profiler.profile(
  41. activities=[
  42. torch.profiler.ProfilerActivity.CPU,
  43. torch.profiler.ProfilerActivity.CUDA,
  44. ],
  45. schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
  46. on_trace_ready=torch.profiler.tensorboard_trace_handler(
  47. cfg.profiler_dir
  48. ),
  49. profile_memory=True,
  50. with_stack=False,
  51. with_flops=True,
  52. record_shapes=True,
  53. ) as torch_profiler:
  54. yield torch_profiler
  55. elif use_flop_counter:
  56. if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
  57. raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
  58. with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
  59. yield flop_counter
  60. else:
  61. torch_profiler = contextlib.nullcontext()
  62. yield None
  63. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
  64. """
  65. Trains the model on the given dataloader
  66. Args:
  67. model: The model to be trained
  68. train_dataloader: The dataloader containing the training data
  69. optimizer: The optimizer used for training
  70. lr_scheduler: The learning rate scheduler
  71. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  72. num_epochs: The number of epochs to train for
  73. local_rank: The rank of the current node in a distributed setting
  74. train_config: The training configuration
  75. eval_dataloader: The dataloader containing the eval data
  76. tokenizer: tokenizer used in the eval for decoding the predicitons
  77. Returns: results dictionary containing average training and validation perplexity and loss
  78. """
  79. # Create a gradient scaler for fp16
  80. if train_config.use_fp16 and train_config.enable_fsdp:
  81. scaler = ShardedGradScaler()
  82. elif train_config.use_fp16 and not train_config.enable_fsdp:
  83. scaler = torch.cuda.amp.GradScaler()
  84. if train_config.enable_fsdp:
  85. world_size = int(os.environ["WORLD_SIZE"])
  86. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  87. train_prep = []
  88. train_loss = []
  89. val_prep = []
  90. val_loss =[]
  91. if train_config.save_metrics:
  92. if not os.path.exists(train_config.output_dir):
  93. os.makedirs(train_config.output_dir, exist_ok=True)
  94. metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
  95. train_step_perplexity = []
  96. train_step_loss = []
  97. val_step_loss = []
  98. val_step_perplexity = []
  99. epoch_times = []
  100. checkpoint_times = []
  101. results = {}
  102. best_val_loss = float("inf")
  103. total_train_steps = 0
  104. max_steps_reached = False # Flag to indicate max training steps reached
  105. # Start the training loop
  106. for epoch in range(train_config.num_epochs):
  107. # stop when the maximum number of training steps is reached
  108. if max_steps_reached:
  109. break
  110. epoch_start_time = time.perf_counter()
  111. with MemoryTrace() as memtrace: # track the memory usage
  112. model.train()
  113. total_loss = 0.0
  114. total_length = len(train_dataloader)//gradient_accumulation_steps
  115. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  116. with profile(train_config,local_rank) as profile_context:
  117. for step, batch in enumerate(train_dataloader):
  118. total_train_steps += 1
  119. # stop when the maximum number of training steps is reached
  120. if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
  121. max_steps_reached = True
  122. if not train_config.enable_fsdp or local_rank==0:
  123. print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
  124. break
  125. for key in batch.keys():
  126. if train_config.enable_fsdp:
  127. if is_xpu_available():
  128. batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
  129. else:
  130. batch[key] = batch[key].to(local_rank)
  131. else:
  132. if is_xpu_available():
  133. batch[key] = batch[key].to('xpu:0')
  134. else:
  135. batch[key] = batch[key].to('cuda:0')
  136. with autocast():
  137. loss = model(**batch).loss
  138. loss = loss / gradient_accumulation_steps
  139. if train_config.save_metrics:
  140. train_step_loss.append(loss.detach().float().item())
  141. train_step_perplexity.append(float(torch.exp(loss.detach().float())))
  142. total_loss += loss.detach().float()
  143. if train_config.use_fp16:
  144. # if fp16 is enabled, use gradient scaler to handle gradient update
  145. scaler.scale(loss).backward()
  146. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  147. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  148. scaler.unscale_(optimizer)
  149. if train_config.enable_fsdp:
  150. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  151. else:
  152. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  153. scaler.step(optimizer)
  154. scaler.update()
  155. optimizer.zero_grad()
  156. pbar.update(1)
  157. else:
  158. # regular backpropagation when fp16 is not used
  159. loss.backward()
  160. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  161. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  162. if train_config.enable_fsdp:
  163. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  164. else:
  165. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  166. optimizer.step()
  167. optimizer.zero_grad()
  168. pbar.update(1)
  169. if train_config.use_profiler or train_config.flop_counter:
  170. profile_context.step()
  171. if train_config.flop_counter and profile_context.is_done():
  172. TFlops = profile_context.get_flops_per_sec() / 1e12
  173. if wandb_run:
  174. if not train_config.enable_fsdp or rank==0:
  175. wandb_run.log({
  176. 'train/epoch': epoch + 1,
  177. 'train/step': epoch * len(train_dataloader) + step,
  178. 'train/loss': loss.detach().float(),
  179. })
  180. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  181. if train_config.save_metrics:
  182. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  183. pbar.close()
  184. epoch_end_time = time.perf_counter()-epoch_start_time
  185. epoch_times.append(epoch_end_time)
  186. # Reducing total_loss across all devices if there's more than one CUDA device
  187. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  188. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  189. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  190. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  191. train_epoch_loss = total_loss / len(train_dataloader)
  192. if train_config.enable_fsdp:
  193. train_epoch_loss = train_epoch_loss/world_size
  194. train_perplexity = torch.exp(train_epoch_loss)
  195. train_prep.append(float(train_perplexity))
  196. train_loss.append(float(train_epoch_loss))
  197. if not train_config.enable_fsdp or rank==0:
  198. memtrace.print_stats()
  199. # Update the learning rate as needed
  200. lr_scheduler.step()
  201. if train_config.run_validation:
  202. eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
  203. if train_config.save_metrics:
  204. val_step_loss.extend(temp_val_loss)
  205. val_step_perplexity.extend(temp_step_perplexity)
  206. checkpoint_start_time = time.perf_counter()
  207. if train_config.save_model and eval_epoch_loss < best_val_loss:
  208. if train_config.enable_fsdp:
  209. dist.barrier()
  210. if train_config.use_peft:
  211. if train_config.enable_fsdp:
  212. if rank==0:
  213. print(f"we are about to save the PEFT modules")
  214. else:
  215. print(f"we are about to save the PEFT modules")
  216. model.save_pretrained(train_config.output_dir)
  217. if train_config.enable_fsdp:
  218. if rank==0:
  219. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  220. else:
  221. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  222. else:
  223. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  224. save_model_checkpoint(
  225. model, optimizer, rank, train_config, epoch=epoch
  226. )
  227. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  228. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  229. print("=====================================================")
  230. save_model_and_optimizer_sharded(model, rank, train_config)
  231. if train_config.save_optimizer:
  232. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  233. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  234. print("=====================================================")
  235. if not train_config.use_peft and train_config.save_optimizer:
  236. save_optimizer_checkpoint(
  237. model, optimizer, rank, train_config, epoch=epoch
  238. )
  239. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  240. print("=====================================================")
  241. if train_config.enable_fsdp:
  242. dist.barrier()
  243. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  244. checkpoint_times.append(checkpoint_end_time)
  245. if eval_epoch_loss < best_val_loss:
  246. best_val_loss = eval_epoch_loss
  247. if train_config.enable_fsdp:
  248. if rank==0:
  249. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  250. else:
  251. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  252. val_loss.append(float(best_val_loss))
  253. val_prep.append(float(eval_ppl))
  254. if train_config.enable_fsdp:
  255. if rank==0:
  256. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  257. else:
  258. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  259. # Saving the results every epoch to plot later
  260. if train_config.save_metrics:
  261. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  262. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  263. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  264. avg_train_prep = sum(train_prep)/len(train_prep)
  265. avg_train_loss = sum(train_loss)/len(train_loss)
  266. if train_config.run_validation:
  267. avg_eval_prep = sum(val_prep)/len(val_prep)
  268. avg_eval_loss = sum(val_loss)/len(val_loss)
  269. results['avg_train_prep'] = avg_train_prep
  270. results['avg_train_loss'] = avg_train_loss
  271. if train_config.run_validation:
  272. results['avg_eval_prep'] = avg_eval_prep
  273. results['avg_eval_loss'] = avg_eval_loss
  274. results["avg_epoch_time"] = avg_epoch_time
  275. results["avg_checkpoint_time"] = avg_checkpoint_time
  276. if train_config.save_metrics:
  277. results["metrics_filename"] = metrics_filename
  278. if train_config.flop_counter:
  279. results["model_tflops"]= TFlops
  280. #saving the training params including fsdp setting for reference.
  281. if train_config.enable_fsdp and not train_config.use_peft and rank==0:
  282. save_train_params(train_config, fsdp_config, rank)
  283. return results
  284. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
  285. """
  286. Evaluates the model on the given dataloader
  287. Args:
  288. model: The model to evaluate
  289. eval_dataloader: The dataloader containing the evaluation data
  290. local_rank: The rank of the current node in a distributed setting
  291. tokenizer: The tokenizer used to decode predictions
  292. Returns: eval_ppl, eval_epoch_loss
  293. """
  294. if train_config.enable_fsdp:
  295. world_size = int(os.environ["WORLD_SIZE"])
  296. model.eval()
  297. eval_preds = []
  298. val_step_loss = []
  299. val_step_perplexity = []
  300. eval_loss = 0.0 # Initialize evaluation loss
  301. total_eval_steps = 0
  302. with MemoryTrace() as memtrace:
  303. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  304. total_eval_steps += 1
  305. # stop when the maximum number of eval steps is reached
  306. if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
  307. if not train_config.enable_fsdp or local_rank==0:
  308. print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
  309. break
  310. for key in batch.keys():
  311. if train_config.enable_fsdp:
  312. batch[key] = batch[key].to(local_rank)
  313. else:
  314. if is_xpu_available():
  315. batch[key] = batch[key].to('xpu:0')
  316. else:
  317. batch[key] = batch[key].to('cuda:0')
  318. # Ensure no gradients are computed for this scope to save memory
  319. with torch.no_grad():
  320. # Forward pass and compute loss
  321. outputs = model(**batch)
  322. loss = outputs.loss
  323. if train_config.save_metrics:
  324. val_step_loss.append(loss.detach().float().item())
  325. val_step_perplexity.append(float(torch.exp(loss.detach().float())))
  326. eval_loss += loss.detach().float()
  327. # Decode predictions and add to evaluation predictions list
  328. preds = torch.argmax(outputs.logits, -1)
  329. eval_preds.extend(
  330. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  331. )
  332. # If there's more than one CUDA device, reduce evaluation loss across all devices
  333. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  334. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  335. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  336. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  337. # Compute average loss and perplexity
  338. eval_epoch_loss = eval_loss / len(eval_dataloader)
  339. if train_config.enable_fsdp:
  340. eval_epoch_loss = eval_epoch_loss/world_size
  341. eval_ppl = torch.exp(eval_epoch_loss)
  342. # Print evaluation metrics
  343. if train_config.enable_fsdp:
  344. if local_rank==0:
  345. print(f" {eval_ppl=} {eval_epoch_loss=}")
  346. else:
  347. print(f" {eval_ppl=} {eval_epoch_loss=}")
  348. if wandb_run:
  349. wandb_run.log({
  350. 'eval/perplexity': eval_ppl,
  351. 'eval/loss': eval_epoch_loss,
  352. }, commit=False)
  353. return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
  354. def freeze_transformer_layers(model, num_layer):
  355. for i, layer in enumerate(model.model.layers):
  356. if i < num_layer:
  357. for param in layer.parameters():
  358. param.requires_grad = False
  359. def check_frozen_layers_peft_model(model):
  360. for i, layer in enumerate(model.base_model.model.model.layers):
  361. for name, param in layer.named_parameters():
  362. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  363. def setup():
  364. """Initialize the process group for distributed training"""
  365. if is_ccl_available():
  366. # distributed training on xpus
  367. dist.init_process_group("ccl")
  368. else:
  369. dist.init_process_group("nccl")
  370. def setup_environ_flags(rank):
  371. """Set environment flags for debugging purposes"""
  372. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  373. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  374. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  375. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  376. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  377. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  378. if rank == 0:
  379. print(f"--> Running with torch dist debug set to detail")
  380. def cleanup():
  381. """Clean up the process group after training"""
  382. dist.destroy_process_group()
  383. def clear_gpu_cache(rank=None):
  384. """Clear the GPU cache for all ranks"""
  385. if rank == 0:
  386. print(f"Clearing GPU cache for all ranks")
  387. if is_xpu_available():
  388. torch.xpu_empty_cache()
  389. else:
  390. torch.cuda.empty_cache()
  391. def get_parameter_dtypes(model):
  392. """Get the data types of model parameters"""
  393. parameter_dtypes = {}
  394. for name, parameter in model.named_parameters():
  395. parameter_dtypes[name] = parameter.dtype
  396. return parameter_dtypes
  397. def print_model_size(model, config, rank: int = 0) -> None:
  398. """
  399. Print model name, the number of trainable parameters and initialization time.
  400. Args:
  401. model: The PyTorch model.
  402. model_name (str): Name of the model.
  403. init_time_start (float): Initialization start time.
  404. init_time_end (float): Initialization end time.
  405. rank (int, optional): Current process's rank. Defaults to 0.
  406. """
  407. if rank == 0:
  408. print(f"--> Model {config.model_name}")
  409. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  410. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  411. def get_policies(cfg, rank):
  412. """Get the policies for mixed precision and fsdp wrapping"""
  413. verify_bfloat_support = ((
  414. torch.version.cuda
  415. and torch.cuda.is_bf16_supported()
  416. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  417. and dist.is_nccl_available()
  418. and nccl.version() >= (2, 10)
  419. ) or
  420. (is_xpu_available()))
  421. mixed_precision_policy = None
  422. wrapping_policy = None
  423. # Mixed precision
  424. if cfg.mixed_precision:
  425. bf16_ready = verify_bfloat_support
  426. if bf16_ready and not cfg.use_fp16:
  427. mixed_precision_policy = bfSixteen
  428. if rank == 0:
  429. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  430. elif cfg.use_fp16:
  431. mixed_precision_policy = fpSixteen
  432. if rank == 0:
  433. print(f"FP16 enabled")
  434. else:
  435. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  436. wrapping_policy = get_llama_wrapper()
  437. return mixed_precision_policy, wrapping_policy
  438. def save_train_params(train_config, fsdp_config, rank):
  439. """
  440. This function saves the train_config and FSDP config into a train_params.yaml.
  441. This will be used by converter script in the inference folder to fetch the HF model name or path.
  442. It also would be hepful as a log for future references.
  443. """
  444. # Convert the train_config and fsdp_config objects to dictionaries,
  445. # converting all values to strings to ensure they can be serialized into a YAML file
  446. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  447. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  448. # Merge the two dictionaries into one
  449. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  450. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  451. folder_name = (
  452. train_config.dist_checkpoint_root_folder
  453. + "/"
  454. + train_config.dist_checkpoint_folder
  455. + "-"
  456. + train_config.model_name
  457. )
  458. save_dir = Path.cwd() / folder_name
  459. # If the directory does not exist, create it
  460. if not os.path.exists(save_dir):
  461. os.makedirs(save_dir)
  462. # Convert the dictionary to a YAML string
  463. config_yaml = yaml.dump(train_params_dict, indent=4)
  464. file_name = os.path.join(save_dir,'train_params.yaml')
  465. # Check if there's a directory with the same name as the file
  466. if os.path.isdir(file_name):
  467. print(f"Error: {file_name} is a directory, not a file.")
  468. else:
  469. # Write the YAML string to the file
  470. with open(file_name, 'w') as f:
  471. f.write(config_yaml)
  472. if rank==0:
  473. print(f"training params are saved in {file_name}")
  474. def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
  475. metrics_data = {
  476. "train_step_loss": train_step_loss,
  477. "train_epoch_loss": train_epoch_loss,
  478. "train_step_perplexity": train_step_ppl,
  479. "train_epoch_perplexity": train_epoch_ppl,
  480. "val_step_loss": val_step_loss,
  481. "val_epoch_loss": val_epoch_loss,
  482. "val_step_perplexity": val_step_ppl,
  483. "val_epoch_perplexity": val_epoch_ppl
  484. }
  485. with open(output_filename, "w") as f:
  486. json.dump(metrics_data, f)