@@ -8,6 +8,7 @@ from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 from datetime import datetime
+import contextlib


 import torch
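This hunk matters because the module previously pulled in only `nullcontext` via `from contextlib import nullcontext`, yet the profiler paths below call `contextlib.nullcontext()` through the module name; without a module-level `import contextlib` that raises `NameError` at runtime. A minimal, self-contained illustration of the failure mode (not taken from the repo):

```python
from contextlib import nullcontext

with nullcontext():                 # fine: the name itself was imported
    pass

try:
    with contextlib.nullcontext():  # NameError: the module is not bound
        pass
except NameError as err:
    print(err)                      # name 'contextlib' is not defined

import contextlib                   # the fix in this hunk

with contextlib.nullcontext():      # now resolves through the module
    pass
```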
@@ -55,7 +56,7 @@ def throughput_measure_context(cfg, local_rank=None):
     elif use_flop_counter:
         torch_profiler = contextlib.nullcontext()
         with FlopCounterMode(rank=local_rank) as flop_counter:
-            yeild flop_counter
+            yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
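`yeild` here is not just cosmetic: inside this generator-based context manager it is a `SyntaxError`, so the module fails to import at all rather than misbehaving at runtime. A minimal sketch of the same pattern, using the stock `torch.utils.flop_counter.FlopCounterMode` (the `rank=` argument in the patched code suggests a repo-local wrapper; the stock class does not accept it):

```python
import contextlib

import torch
from torch.utils.flop_counter import FlopCounterMode

@contextlib.contextmanager
def flop_measure_context(use_flop_counter=True):
    # Code before `yield` runs on __enter__, code after on __exit__;
    # a misspelled `yield` makes this a SyntaxError at import time.
    if use_flop_counter:
        with FlopCounterMode(display=False) as flop_counter:
            yield flop_counter
    else:
        yield None

model = torch.nn.Linear(16, 16)
with flop_measure_context() as counter:
    model(torch.randn(4, 16))
print(counter.get_total_flops())
```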
@@ -135,10 +136,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 if not train_config.enable_fsdp or local_rank==0:
                     print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
                 break
-            if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
+            if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
                 print("start flop counting at the step: ", total_train_steps)
                 measure_context.start()
-            if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
+            if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
                 print("stop flop counting at the step: ", total_train_steps)
                 TFlops = get_total_flops(flop_counter) / 1e12
                 measure_context.stop()
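Besides renaming `traing_config`, the two guards implement a one-step measurement window: counting is armed when `total_train_steps` hits `flop_counter_startpoint` and read out on the very next step, so `TFlops` reflects exactly one training step. A toy sketch of that windowing (the `Counter` class is hypothetical, standing in for `measure_context`/`flop_counter`):

```python
class Counter:
    """Hypothetical stand-in for measure_context / flop_counter."""
    def __init__(self):
        self.active = False
        self.flops = 0.0

    def start(self):
        self.active = True

    def stop(self):
        self.active = False

    def add(self, flops):
        if self.active:
            self.flops += flops

flop_counter_startpoint = 3
counter = Counter()
for total_train_steps in range(1, 6):
    if total_train_steps == flop_counter_startpoint:
        counter.start()                    # arm at the startpoint step
    if total_train_steps == flop_counter_startpoint + 1:
        print(counter.flops / 1e12)        # exactly one step's work: 1.0
        counter.stop()
    counter.add(1e12)                      # pretend each step costs 1 TFLOP
```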