@@ -8,6 +8,7 @@ from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 from datetime import datetime
+import contextlib


 import torch
@@ -24,14 +25,48 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
-
+from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"

-# Converting Bytes to Megabytes
-def byte2mb(x):
-    return int(x / 2**20)
+@contextlib.contextmanager
+def profile(cfg, local_rank=None):
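+    """Context manager that yields a torch profiler, a FlopMeasure counter, or None,
+    depending on which profiling option is enabled in cfg."""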
+    use_profiler: bool = cfg.use_profiler
+    use_flop_counter: bool = cfg.flop_counter
+    if use_flop_counter and use_profiler:
+        raise ValueError("Cannot use both profiler and flop counter")
+    if use_profiler:
+        # the profiler needs a warmup stage to produce accurate profiling results
+        wait_step, warmup_step, active_step = 1, 2, 3
+        min_step = wait_step + warmup_step + active_step + 1
+        if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the wait ({wait_step}), warmup ({warmup_step}) and active ({active_step}) stages; please increase max_train_step (currently {cfg.max_train_step})")
+        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                cfg.profiler_dir
+            ),
+            profile_memory=True,
+            with_stack=False,
+            with_flops=True,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    elif use_flop_counter:
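+        # FlopMeasure treats the first cfg.flop_counter_start steps as warmup before counting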
+        if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps; please increase max_train_step (currently {cfg.max_train_step})")
+        with FlopMeasure(rank=local_rank, warmup_step=cfg.flop_counter_start) as flop_counter:
+            yield flop_counter
+    else:
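+        # profiling is disabled; yield None so the caller's with-block still works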
+        yield None
+

 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
@@ -62,7 +97,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche


     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-
     train_prep = []
     train_loss = []
     val_prep = []
@@ -92,73 +126,77 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            for step, batch in enumerate(train_dataloader):
-                total_train_steps += 1
-                # stop when the maximum number of training steps is reached
-                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
-                    max_steps_reached = True
-                    if not train_config.enable_fsdp or local_rank==0:
-                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
-                    break
-                for key in batch.keys():
-                    if train_config.enable_fsdp:
-                        if is_xpu_available():
-                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+            with profile(train_config, local_rank) as profile_context:
+                for step, batch in enumerate(train_dataloader):
+                    total_train_steps += 1
+                    # stop when the maximum number of training steps is reached
+                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                        max_steps_reached = True
+                        if not train_config.enable_fsdp or local_rank==0:
+                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
+                        break
+                    for key in batch.keys():
+                        if train_config.enable_fsdp:
+                            if is_xpu_available():
+                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                            else:
+                                batch[key] = batch[key].to(local_rank)
                         else:
-                            batch[key] = batch[key].to(local_rank)
-                    else:

-                        if is_xpu_available():
-                            batch[key] = batch[key].to('xpu:0')
-                        else:
-                            batch[key] = batch[key].to('cuda:0')
-                with autocast():
-                    loss = model(**batch).loss
-                loss = loss / gradient_accumulation_steps
-                if train_config.save_metrics:
-                    train_step_loss.append(loss.detach().float().item())
-                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
-                total_loss += loss.detach().float()
-                if train_config.use_fp16:
-                    # if fp16 is enabled, use gradient scaler to handle gradient update
-                    scaler.scale(loss).backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            scaler.unscale_(optimizer)
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            if is_xpu_available():
+                                batch[key] = batch[key].to('xpu:0')
                             else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-                else:
-                    # regular backpropagation when fp16 is not used
-                    loss.backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
-                            else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        optimizer.step()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-
-                if wandb_run:
-                    if not train_config.enable_fsdp or rank==0:
-                        wandb_run.log({
-                            'train/epoch': epoch + 1,
-                            'train/step': epoch * len(train_dataloader) + step,
-                            'train/loss': loss.detach().float(),
-                        })
-
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-
-                if train_config.save_metrics:
-                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
-            pbar.close()
+                                batch[key] = batch[key].to('cuda:0')
+                    with autocast():
+                        loss = model(**batch).loss
+                    loss = loss / gradient_accumulation_steps
+                    if train_config.save_metrics:
+                        train_step_loss.append(loss.detach().float().item())
+                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                    total_loss += loss.detach().float()
+                    if train_config.use_fp16:
+                        # if fp16 is enabled, use gradient scaler to handle gradient update
+                        scaler.scale(loss).backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                scaler.unscale_(optimizer)
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            scaler.step(optimizer)
+                            scaler.update()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    else:
+                        # regular backpropagation when fp16 is not used
+                        loss.backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
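+                    # advance the profiler / flop counter schedule once per batch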
+                    if train_config.use_profiler or train_config.flop_counter:
+                        profile_context.step()
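+                    # once the flop counter finishes its measurement window, record throughput in TFLOPs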
+                    if train_config.flop_counter and profile_context.is_done():
+                        TFlops = profile_context.get_flops_per_sec() / 1e12
+                    if wandb_run:
+                        if not train_config.enable_fsdp or rank==0:
+                            wandb_run.log({
+                                'train/epoch': epoch + 1,
+                                'train/step': epoch * len(train_dataloader) + step,
+                                'train/loss': loss.detach().float(),
+                            })
+
+                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                    if train_config.save_metrics:
+                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                pbar.close()

         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
@@ -180,7 +218,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

         # Update the learning rate as needed
         lr_scheduler.step()
-
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
@@ -266,7 +303,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
-
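+    # report the measured throughput alongside the other training results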
+    if train_config.flop_counter:
+        results["model_tflops"] = TFlops
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft and rank==0:
         save_train_params(train_config, fsdp_config, rank)