@@ -31,7 +31,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.padding_side = "left"
 
 @contextlib.contextmanager
-def throughput_measure_context(cfg, local_rank=None):
+def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
     use_flop_counter: bool = cfg.flop_counter
     if use_flop_counter and use_profiler:
@@ -41,7 +41,7 @@ def throughput_measure_context(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
         print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
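
Note: the wait/warmup/active numbers in the hunk above are what drive the min_step = wait_step + warmup_step + active_step + 1 check. Below is a minimal sketch (not part of this patch) of how such a schedule is typically wired into torch.profiler.profile; cfg.profiler_dir is taken from the hunk, everything else is illustrative:

    import torch

    def build_profiler_sketch(cfg):
        # The error check above requires one step more than wait + warmup + active,
        # so the recording (active) stage can actually complete during training.
        wait_step, warmup_step, active_step = 1, 2, 3
        return torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait_step, warmup=warmup_step, active=active_step, repeat=1
            ),
            # write traces where the hunk above says results are saved
            on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir),
            profile_memory=True,
            with_stack=False,
            record_shapes=True,
        )
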
@@ -97,7 +97,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-
     train_prep = []
     train_loss = []
     val_prep = []
@@ -127,7 +126,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with throughput_measure_context(train_config,local_rank) as measure_context:
+            with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
@@ -138,7 +137,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         break
                     if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
                         print("start flop counting at the step: ", total_train_steps)
-                        measure_context.start_counting()
+                        profile_context.start_counting()
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
|
|
|
optimizer.zero_grad()
|
|
|
pbar.update(1)
|
|
|
if train_config.use_profiler:
|
|
|
- measure_context.step()
|
|
|
- if train_config.flop_counter and measure_context.is_ready():
|
|
|
- TFlops = measure_context.get_total_flops() / 1e12
|
|
|
- measure_context.stop_counting()
|
|
|
+ profile_context.step()
|
|
|
+ if train_config.flop_counter and profile_context.is_ready():
|
|
|
+ TFlops = profile_context.get_total_flops() / 1e12
|
|
|
+ profile_context.stop_counting()
|
|
|
if wandb_run:
|
|
|
if not train_config.enable_fsdp or rank==0:
|
|
|
wandb_run.log({
|