@@ -38,7 +38,7 @@ def throughput_measure_context(cfg, local_rank=None):
     if use_flop_counter and use_profiler:
         raise ValueError("Cannot use both profiler and flop counter")
     if use_profiler:
-        print(f"profiling is activated and results will be saved in {cfg.profile_dir}")
+        print(f"profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
@@ -46,7 +46,7 @@ def throughput_measure_context(cfg, local_rank=None):
             ],
             schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
             on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profile_dir
+                cfg.profiler_dir
             ),
             profile_memory=True,
             with_stack=False,
@@ -138,11 +138,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     break
                 if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
                     print("start flop counting at the step: ", total_train_steps)
-                    measure_context.start()
-                if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
-                    print("stop flop counting at the step: ", total_train_steps)
-                    TFlops = get_total_flops(flop_counter) / 1e12
-                    measure_context.stop()
+                    measure_context.start_counting()
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         if is_xpu_available():
@@ -188,7 +184,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
-
+                    if train_config.flop_counter and measure_context.is_ready():
+                        TFlops = get_total_flops(measure_context) / 1e12
+                        measure_context.stop_counting()
                 if wandb_run:
                     if not train_config.enable_fsdp or rank==0:
                         wandb_run.log({
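
For context: the new hooks in the last two hunks rely on a measurement object exposing start_counting(), is_ready(), and stop_counting(), plus a get_total_flops() helper. Below is a minimal sketch of one way to back that API with torch.utils.flop_counter.FlopCounterMode; the class name and body are illustrative assumptions, not the PR's actual implementation.

from torch.utils.flop_counter import FlopCounterMode


class FlopMeasureContext:
    """Sketch of the API used by the diff: count FLOPs for one train step."""

    def __init__(self):
        # display=False suppresses FlopCounterMode's summary table on exit
        self.flop_counter = FlopCounterMode(display=False)
        self._counting = False
        self._done = False

    def start_counting(self):
        # Enter FlopCounterMode so every subsequent tensor op is counted.
        self.flop_counter.__enter__()
        self._counting = True

    def is_ready(self):
        # True once counting has started and the result has not been read yet.
        return self._counting and not self._done

    def stop_counting(self):
        # Leave the mode; accumulated counts remain readable on the counter.
        self.flop_counter.__exit__(None, None, None)
        self._counting = False
        self._done = True


def get_total_flops(measure_context):
    # Sum the FLOPs recorded across all modules.
    return measure_context.flop_counter.get_total_flops()

Counting a single step this way keeps FlopCounterMode's dispatch overhead off the rest of the run. The profiler branch is similarly scoped: schedule(wait=1, warmup=2, active=3, repeat=1) skips one step, warms up for two, then records three steps into the TensorBoard trace under cfg.profiler_dir.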