
Improve model checkpoint saving logic (#691)

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
Lucas Ventura 6 months ago
Commit 2774065891
1 changed file with 49 additions and 45 deletions

+ 49 - 45
src/llama_recipes/utils/train_utils.py

@@ -220,70 +220,74 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         # Update the learning rate as needed
         lr_scheduler.step()
+        should_save_model = train_config.save_model
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-
-            checkpoint_start_time = time.perf_counter()
-            if train_config.save_model and eval_epoch_loss < best_val_loss:
+            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
+        
+        checkpoint_start_time = time.perf_counter()
+        if should_save_model:
+            if train_config.enable_fsdp:
+                dist.barrier()
+            if train_config.use_peft:
                 if train_config.enable_fsdp:
-                    dist.barrier()
-                if train_config.use_peft:
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"we are about to save the PEFT modules")
-                    else:
+                    if rank==0:
                         print(f"we are about to save the PEFT modules")
-                    save_peft_checkpoint(model, train_config.output_dir)
-                    if train_config.enable_fsdp:
-                        if rank==0:
-                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                    else:
+                else:
+                    print(f"we are about to save the PEFT modules")
+                save_peft_checkpoint(model, train_config.output_dir)
+                if train_config.enable_fsdp:
+                    if rank==0:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
                 else:
-                    if not train_config.enable_fsdp:
-                        save_model_checkpoint(model, train_config.output_dir)
-                        
-                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
+
+            else:
+                if not train_config.enable_fsdp:
+                    save_model_checkpoint(model, train_config.output_dir)
+                    
+                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                    print("=====================================================")
+                    save_fsdp_model_checkpoint_full(
+                        model, optimizer, rank, train_config, epoch=epoch
+                    )
+                    
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
                         print("=====================================================")
-                        save_fsdp_model_checkpoint_full(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                        
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                            print("=====================================================")
-                            save_optimizer_checkpoint(
-                                model, optimizer, rank, train_config, epoch=epoch
-                            )
-                        
-                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-
-                        if train_config.save_optimizer:
-                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                        else:
-                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                            print("=====================================================")
-                            save_model_and_optimizer_sharded(model, rank, train_config)
+                    
+                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-                        
-                if train_config.enable_fsdp:
-                    dist.barrier()
-            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
-            checkpoint_times.append(checkpoint_end_time)
+                    if train_config.save_optimizer:
+                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    else:
+                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                        print("=====================================================")
+                        save_model_and_optimizer_sharded(model, rank, train_config)
+
+                    
+            if train_config.enable_fsdp:
+                dist.barrier()
+        checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+        checkpoint_times.append(checkpoint_end_time)
+
+        if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(best_val_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
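
The net effect of the diff is easier to see outside the +/- markers: checkpoint saving is hoisted out of the `run_validation` branch and gated by a new `should_save_model` flag computed up front. Below is a minimal sketch of that gating logic; the helper name `should_save_checkpoint` is hypothetical and stands in for the inline flag, while the field names (`save_model`, `run_validation`, `eval_epoch_loss`, `best_val_loss`) mirror those used in `train_utils.py`.

```python
def should_save_checkpoint(save_model: bool,
                           run_validation: bool,
                           eval_epoch_loss: float,
                           best_val_loss: float) -> bool:
    """Decide whether to write a checkpoint for the current epoch.

    Before this change the checkpoint block lived inside the validation
    branch, so disabling run_validation also disabled checkpointing.
    After the change, save_model alone is enough when validation is off;
    when validation is on, a checkpoint is written only if the eval loss
    improved on the best seen so far.
    """
    if not run_validation:
        return save_model
    return save_model and eval_epoch_loss < best_val_loss
```

In other words, with `run_validation` disabled, `save_model` alone now triggers a checkpoint each epoch, whereas previously no checkpoint was written at all in that configuration; the `best_val_loss` bookkeeping stays inside the validation branch, as shown at the end of the hunk.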