
Handle incorrect profiling early stop caused by max_train_step and add profiler.step() for each train step
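For context: torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1) only emits a trace once the active window completes, and the schedule only advances when profiler.step() is called, so a run capped by max_train_step below that threshold (or a loop that never calls step()) produces no profile at all. A minimal, self-contained sketch of that contract (the matmul stands in for a real training step, and ./profile_out is a hypothetical output directory, not a path from this repo):

    import torch
    from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

    wait_step, warmup_step, active_step = 1, 2, 3
    # The trace is only written after wait + warmup + active calls to
    # prof.step(); the commit adds one extra step of headroom, hence
    # min_step = wait + warmup + active + 1.
    min_step = wait_step + warmup_step + active_step + 1

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        schedule=schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
        on_trace_ready=tensorboard_trace_handler("./profile_out"),  # hypothetical dir
    ) as prof:
        for _ in range(min_step):
            x = torch.randn(256, 256)
            (x @ x).sum()  # stand-in for one training step
            prof.step()    # advance the schedule exactly once per step

Running this writes one TensorBoard trace once the wait/warmup/active steps have all completed, which is the condition the new min_step check below enforces.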

Kai Wu 1 year ago
parent commit 69e46887b4
1 changed file with 13 additions and 5 deletions

+ 13 - 5
src/llama_recipes/utils/train_utils.py

@@ -38,23 +38,30 @@ def throughput_measure_context(cfg, local_rank=None):
     if use_flop_counter and use_profiler:
         raise ValueError("Cannot use both profiler and flop counter")
     if use_profiler:
-        print(f"profiling is activated and results will be saved in {cfg.profiler_dir}")
+        # the profiler needs a warmup stage before it produces accurate results
+        wait_step, warmup_step, active_step = 1, 2, 3
+        min_step = wait_step + warmup_step + active_step + 1
+        if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps; please increase max_train_step (current: {cfg.max_train_step})")
+        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
-            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
+            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
             on_trace_ready=torch.profiler.tensorboard_trace_handler(
                 cfg.profiler_dir
             ),
             profile_memory=True,
             with_stack=False,
+            with_flops=True,
             record_shapes=True,
         ) as torch_profiler:
             yield torch_profiler
     elif use_flop_counter:
-        torch_profiler = contextlib.nullcontext()
+        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_startpoint:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_startpoint} train steps; please increase max_train_step (current: {cfg.max_train_step})")
         with FlopCounterMode(rank=local_rank) as flop_counter:
             yield flop_counter
     else:
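
The flop-counter branch above gets the same guard: if max_train_step stops training before flop_counter_startpoint, counting would never begin. Note that FlopCounterMode here is llama_recipes' own wrapper (it takes rank and exposes is_ready()/stop_counting()); purely as an illustration of the underlying mechanism, PyTorch's upstream torch.utils.flop_counter.FlopCounterMode can be driven like this:

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    model = torch.nn.Linear(128, 128)
    x = torch.randn(32, 128)

    # Count flops for one forward + backward pass.
    with FlopCounterMode(display=False) as flop_counter:
        model(x).sum().backward()

    # Mirrors TFlops = get_total_flops(measure_context) / 1e12 in the diff below.
    tflops = flop_counter.get_total_flops() / 1e12
    print(f"forward+backward TFLOPs: {tflops:.6f}")
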
@@ -134,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                         max_steps_reached = True
                         if not train_config.enable_fsdp or local_rank==0:
-                            print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
+                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
                         break
                     if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
                         print("start flop counting at the step: ", total_train_steps)
@@ -184,6 +191,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
+                    if train_config.use_profiler:
+                        measure_context.step()
                     if train_config.flop_counter and measure_context.is_ready():
                         TFlops = get_total_flops(measure_context) / 1e12
                         measure_context.stop_counting()
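
The new measure_context.step() call sits inside the per-batch loop, after the gradient-accumulation branch, so the profiler schedule advances exactly once per training step and the wait/warmup/active windows stay lined up with the step counter that the early-stop check relies on, whether or not gradients are being accumulated.
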
@@ -221,7 +230,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         # Update the learning rate as needed
         lr_scheduler.step()
-
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics: