|
@@ -38,23 +38,30 @@ def throughput_measure_context(cfg, local_rank=None):
|
|
|
if use_flop_counter and use_profiler:
|
|
|
raise ValueError("Cannot use both profiler and flop counter")
|
|
|
if use_profiler:
|
|
|
- print(f"profiling is activated and results will be saved in {cfg.profiler_dir}")
|
|
|
+ # profiler needs a warmup stage to get the accurate profiling results
|
|
|
+ wait_step, warmup_step, active_step = 1, 2, 3
|
|
|
+ min_step = wait_step + warmup_step + active_step + 1
|
|
|
+ if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
|
|
|
+ raise ValueError(f"pytorch profiler requires at least {min_step} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
|
|
|
+ print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
|
|
|
with torch.profiler.profile(
|
|
|
activities=[
|
|
|
torch.profiler.ProfilerActivity.CPU,
|
|
|
torch.profiler.ProfilerActivity.CUDA,
|
|
|
],
|
|
|
- schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
|
|
|
+ schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
|
|
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
|
|
cfg.profiler_dir
|
|
|
),
|
|
|
profile_memory=True,
|
|
|
with_stack=False,
|
|
|
+ with_flops=True,
|
|
|
record_shapes=True,
|
|
|
) as torch_profiler:
|
|
|
yield torch_profiler
|
|
|
elif use_flop_counter:
|
|
|
- torch_profiler = contextlib.nullcontext()
|
|
|
+ if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_startpoint:
|
|
|
+ raise ValueError(f"flop counter requires at least {cfg.flop_counter_startpoint} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
|
|
|
with FlopCounterMode(rank=local_rank) as flop_counter:
|
|
|
yield flop_counter
|
|
|
else:
|
|
@@ -134,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
|
|
|
max_steps_reached = True
|
|
|
if not train_config.enable_fsdp or local_rank==0:
|
|
|
- print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
|
|
|
+ print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
|
|
|
break
|
|
|
if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
|
|
|
print("start flop counting at the step: ", total_train_steps)
|
|
@@ -184,6 +191,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
optimizer.step()
|
|
|
optimizer.zero_grad()
|
|
|
pbar.update(1)
|
|
|
+ if train_config.use_profiler:
|
|
|
+ measure_context.step()
|
|
|
if train_config.flop_counter and measure_context.is_ready():
|
|
|
TFlops = get_total_flops(measure_context) / 1e12
|
|
|
measure_context.stop_counting()
|
|
@@ -221,7 +230,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
|
|
|
# Update the learning rate as needed
|
|
|
lr_scheduler.step()
|
|
|
-
|
|
|
if train_config.run_validation:
|
|
|
eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
|
|
|
if train_config.save_metrics:
|