@@ -1,34 +1,34 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import contextlib
+import json
 import os
 import time
-import yaml
 from contextlib import nullcontext
-from pathlib import Path
 from datetime import datetime
-import contextlib
-
+from pathlib import Path
 
 import torch
-import torch.cuda.nccl as nccl
 import torch.distributed as dist
+import yaml
+from accelerate.utils import is_ccl_available, is_xpu_available
+
+from llama_recipes.model_checkpointing import save_checkpoint
+from llama_recipes.policies import bfSixteen, fpSixteen, get_llama_wrapper
+from llama_recipes.utils.flop_utils import FlopMeasure
+from llama_recipes.utils.memory_utils import MemoryTrace
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
-import json
 
 
-from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
-from llama_recipes.utils.memory_utils import MemoryTrace
-from accelerate.utils import is_xpu_available, is_ccl_available
-from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
 
+
 @contextlib.contextmanager
 def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
@@ -40,17 +40,21 @@ def profile(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
+            raise ValueError(
+                f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        print(
+            f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}"
+        )
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
-            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profiler_dir
+            schedule=torch.profiler.schedule(
+                wait=wait_step, warmup=warmup_step, active=active_step, repeat=1
             ),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir),
             profile_memory=True,
             with_stack=False,
             with_flops=True,
@@ -59,15 +63,32 @@ def profile(cfg, local_rank=None):
             yield torch_profiler
     elif use_flop_counter:
         if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
+            raise ValueError(
+                f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}"
+            )
+        with FlopMeasure(
+            rank=local_rank, warmup_step=cfg.flop_counter_start
+        ) as flop_counter:
             yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
 
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
+def train(
+    model,
+    train_dataloader,
+    eval_dataloader,
+    tokenizer,
+    optimizer,
+    lr_scheduler,
+    gradient_accumulation_steps,
+    train_config,
+    fsdp_config=None,
+    local_rank=None,
+    rank=None,
+    wandb_run=None,
+):
     """
     Trains the model on the given dataloader
 
@@ -93,13 +114,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
 
-
-
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
     train_prep = []
     train_loss = []
     val_prep = []
-    val_loss =[]
+    val_loss = []
 
     if train_config.save_metrics:
        if not os.path.exists(train_config.output_dir):
@@ -127,45 +146,70 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with profile(train_config,local_rank) as profile_context:
+            total_length = len(train_dataloader) // gradient_accumulation_steps
+            pbar = tqdm(
+                colour="blue",
+                desc=f"Training Epoch: {epoch+1}",
+                total=total_length,
+                dynamic_ncols=True,
+            )
+            with profile(train_config, local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
-                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                    if (
+                        train_config.max_train_step > 0
+                        and total_train_steps > train_config.max_train_step
+                    ):
                         max_steps_reached = True
-                        if not train_config.enable_fsdp or local_rank==0:
-                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
+                        if not train_config.enable_fsdp or local_rank == 0:
+                            print(
+                                "max training steps reached, stopping training, total train steps finished: ",
+                                total_train_steps - 1,
+                            )
                         break
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
-                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                                batch[key] = batch[key].to(
+                                    torch.device(f"xpu:{local_rank}")
+                                )
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
                             if is_xpu_available():
-                                batch[key] = batch[key].to('xpu:0')
+                                batch[key] = batch[key].to("xpu:0")
                             elif torch.cuda.is_available():
-                                batch[key] = batch[key].to('cuda:0')
+                                batch[key] = batch[key].to("cuda:0")
                     with autocast():
                         loss = model(**batch).loss
                     total_loss += loss.detach().float()
                     loss = loss / gradient_accumulation_steps
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
-                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                        train_step_perplexity.append(
+                            float(torch.exp(loss.detach().float()))
+                        )
                     if train_config.use_fp16:
                         # if fp16 is enabled, use gradient scaler to handle gradient update
                         scaler.scale(loss).backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 scaler.unscale_(optimizer)
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             scaler.step(optimizer)
                             scaler.update()
                             optimizer.zero_grad()
@@ -173,12 +217,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     else:
                         # regular backpropagation when fp16 is not used
                         loss.backward()
-                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                            train_dataloader
+                        ) - 1:
+                            if (
+                                train_config.gradient_clipping
+                                and train_config.gradient_clipping_threshold > 0.0
+                            ):
                                 if train_config.enable_fsdp:
-                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                    model.clip_grad_norm_(
+                                        train_config.gradient_clipping_threshold
+                                    )
                                 else:
-                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                                    torch.nn.utils.clip_grad_norm_(
+                                        model.parameters(),
+                                        train_config.gradient_clipping_threshold,
+                                    )
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
@@ -187,96 +241,71 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.flop_counter and profile_context.is_done():
                         TFlops = profile_context.get_flops_per_sec() / 1e12
                     if wandb_run:
-                        if not train_config.enable_fsdp or rank==0:
-                            wandb_run.log({
-                                'train/epoch': epoch + 1,
-                                'train/step': epoch * len(train_dataloader) + step,
-                                'train/loss': loss.detach().float(),
-                            })
-
-                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                        if not train_config.enable_fsdp or rank == 0:
+                            wandb_run.log(
+                                {
+                                    "train/epoch": epoch + 1,
+                                    "train/step": epoch * len(train_dataloader) + step,
+                                    "train/loss": loss.detach().float(),
+                                }
+                            )
+
+                    pbar.set_description(
+                        f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
+                    )
 
                     if train_config.save_metrics:
-                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                        save_to_json(
+                            metrics_filename,
+                            train_step_loss,
+                            train_loss,
+                            train_step_perplexity,
+                            train_prep,
+                            val_step_loss,
+                            val_loss,
+                            val_step_perplexity,
+                            val_prep,
+                        )
             pbar.close()
 
-        epoch_end_time = time.perf_counter()-epoch_start_time
+        epoch_end_time = time.perf_counter() - epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        if is_xpu_available() and (
+            torch.xpu.device_count() > 1 and train_config.enable_fsdp
+        ):
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
-            train_epoch_loss = train_epoch_loss/world_size
+            train_epoch_loss = train_epoch_loss / world_size
         train_perplexity = torch.exp(train_epoch_loss)
 
         train_prep.append(float(train_perplexity))
         train_loss.append(float(train_epoch_loss))
 
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
            memtrace.print_stats()
 
         # Update the learning rate as needed
         lr_scheduler.step()
         should_save_model = train_config.save_model
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(
+                model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run
+            )
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
-
+            should_save_model = (
+                train_config.save_model and eval_epoch_loss < best_val_loss
+            )
+
         checkpoint_start_time = time.perf_counter()
         if should_save_model:
-            if train_config.enable_fsdp:
-                dist.barrier()
-            if train_config.use_peft:
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"we are about to save the PEFT modules")
-                else:
-                    print(f"we are about to save the PEFT modules")
-                save_peft_checkpoint(model, train_config.output_dir)
-                if train_config.enable_fsdp:
-                    if rank==0:
-                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                else:
-                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
-
-            else:
-                if not train_config.enable_fsdp:
-                    save_model_checkpoint(model, train_config.output_dir)
-
-                elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                    print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
-                    print("=====================================================")
-                    save_fsdp_model_checkpoint_full(
-                        model, optimizer, rank, train_config, epoch=epoch
-                    )
-
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP optimizer using FULL_STATE_DICT")
-                        print("=====================================================")
-                        save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
-
-                elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+            save_checkpoint(model, optimizer, train_config, fsdp_config, epoch)
 
-                    if train_config.save_optimizer:
-                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                    else:
-                        print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
-                        print("=====================================================")
-                        save_model_and_optimizer_sharded(model, rank, train_config)
-
-
-            if train_config.enable_fsdp:
-                dist.barrier()
         checkpoint_end_time = time.perf_counter() - checkpoint_start_time
         checkpoint_times.append(checkpoint_end_time)
 
@@ -284,48 +313,67 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
-                    if rank==0:
+                    if rank == 0:
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(float(eval_epoch_loss))
             val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
-            if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            if rank == 0:
+                print(
+                    f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+                )
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+            print(
+                f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
 
         # Saving the results every epoch to plot later
         if train_config.save_metrics:
-            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+            save_to_json(
+                metrics_filename,
+                train_step_loss,
+                train_loss,
+                train_step_perplexity,
+                train_prep,
+                val_step_loss,
+                val_loss,
+                val_step_perplexity,
+                val_prep,
+            )
 
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_prep = sum(train_prep)/len(train_prep)
-    avg_train_loss = sum(train_loss)/len(train_loss)
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = (
+        sum(checkpoint_times) / len(checkpoint_times)
+        if len(checkpoint_times) > 0
+        else 0
+    )
+    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_loss = sum(train_loss) / len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep)
-        avg_eval_loss = sum(val_loss)/len(val_loss)
+        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_loss = sum(val_loss) / len(val_loss)
 
-    results['avg_train_prep'] = avg_train_prep
-    results['avg_train_loss'] = avg_train_loss
+    results["avg_train_prep"] = avg_train_prep
+    results["avg_train_loss"] = avg_train_loss
     if train_config.run_validation:
-        results['avg_eval_prep'] = avg_eval_prep
-        results['avg_eval_loss'] = avg_eval_loss
+        results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_loss"] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
     if train_config.flop_counter:
-        results["model_tflops"]= TFlops
-    #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and not train_config.use_peft and rank==0:
+        results["model_tflops"] = TFlops
+    # saving the training params including fsdp setting for reference.
+    if train_config.enable_fsdp and not train_config.use_peft and rank == 0:
         save_train_params(train_config, fsdp_config, rank)
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
+
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -346,21 +394,34 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     eval_loss = 0.0  # Initialize evaluation loss
     total_eval_steps = 0
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+        for step, batch in enumerate(
+            tqdm(
+                eval_dataloader,
+                colour="green",
+                desc="evaluating Epoch",
+                dynamic_ncols=True,
+            )
+        ):
             total_eval_steps += 1
             # stop when the maximum number of eval steps is reached
-            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
-                if not train_config.enable_fsdp or local_rank==0:
-                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
+            if (
+                train_config.max_eval_step > 0
+                and total_eval_steps > train_config.max_eval_step
+            ):
+                if not train_config.enable_fsdp or local_rank == 0:
+                    print(
+                        "max eval steps reached, stopping evaluation, total_eval_steps: ",
+                        total_eval_steps - 1,
+                    )
                 break
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
                     if is_xpu_available():
-                        batch[key] = batch[key].to('xpu:0')
+                        batch[key] = batch[key].to("xpu:0")
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+                        batch[key] = batch[key].to("cuda:0")
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
@@ -374,11 +435,15 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
-                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
+                tokenizer.batch_decode(
+                    preds.detach().cpu().numpy(), skip_special_tokens=True
+                )
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
-    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+    if is_xpu_available() and (
+        torch.xpu.device_count() > 1 and train_config.enable_fsdp
+    ):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
@@ -386,35 +451,39 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
-        eval_epoch_loss = eval_epoch_loss/world_size
+        eval_epoch_loss = eval_epoch_loss / world_size
     eval_ppl = torch.exp(eval_epoch_loss)
 
     # Print evaluation metrics
     if train_config.enable_fsdp:
-        if local_rank==0:
+        if local_rank == 0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run:
-        wandb_run.log({
-                        'eval/perplexity': eval_ppl,
-                        'eval/loss': eval_epoch_loss,
-                    }, commit=False)
+        wandb_run.log(
+            {
+                "eval/perplexity": eval_ppl,
+                "eval/loss": eval_epoch_loss,
+            },
+            commit=False,
+        )
 
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
+
 def freeze_transformer_layers(model, num_layer):
-   for i, layer in enumerate(model.model.layers):
-            if i < num_layer:
-                for param in layer.parameters():
-                    param.requires_grad = False
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
 
 
 def check_frozen_layers_peft_model(model):
-     for i, layer in enumerate(model.base_model.model.model.layers):
-            for name, param in layer.named_parameters():
-                print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
 
 
 def setup():
@@ -460,58 +529,6 @@ def get_parameter_dtypes(model):
         parameter_dtypes[name] = parameter.dtype
     return parameter_dtypes
 
-def print_model_size(model, config, rank: int = 0) -> None:
-    """
-    Print model name, the number of trainable parameters and initialization time.
-
-    Args:
-        model: The PyTorch model.
-        model_name (str): Name of the model.
-        init_time_start (float): Initialization start time.
-        init_time_end (float): Initialization end time.
-        rank (int, optional): Current process's rank. Defaults to 0.
-    """
-    if rank == 0:
-        print(f"--> Model {config.model_name}")
-        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
-
-
-
-
-def get_policies(cfg, rank):
-    """Get the policies for mixed precision and fsdp wrapping"""
-
-
-    verify_bfloat_support = ((
-    torch.version.cuda
-    and torch.cuda.is_bf16_supported()
-    and torch.version.cuda >= "11.0"
-    and dist.is_nccl_available()
-    and nccl.version() >= (2, 10)
-    ) or
-    (is_xpu_available()))
-
-
-    mixed_precision_policy = None
-    wrapping_policy = None
-
-    # Mixed precision
-    if cfg.mixed_precision:
-        bf16_ready = verify_bfloat_support
-
-        if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen
-            if rank == 0:
-                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
-        elif cfg.use_fp16:
-            mixed_precision_policy = fpSixteen
-            if rank == 0:
-                print(f"FP16 enabled")
-        else:
-            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
-    wrapping_policy = get_llama_wrapper()
-    return mixed_precision_policy, wrapping_policy
 
 def save_train_params(train_config, fsdp_config, rank):
     """
@@ -521,17 +538,21 @@ def save_train_params(train_config, fsdp_config, rank):
     """
     # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
-    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
-    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    train_config_dict = {
+        k: str(v) for k, v in vars(train_config).items() if not k.startswith("__")
+    }
+    fsdp_config_dict = {
+        k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__")
+    }
     # Merge the two dictionaries into one
     train_params_dict = {**train_config_dict, **fsdp_config_dict}
     # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
     folder_name = (
-    train_config.dist_checkpoint_root_folder
-    + "/"
-    + train_config.dist_checkpoint_folder
-    + "-"
-    + train_config.model_name
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
     )
 
     save_dir = Path.cwd() / folder_name
@@ -540,19 +561,30 @@ def save_train_params(train_config, fsdp_config, rank):
         os.makedirs(save_dir)
     # Convert the dictionary to a YAML string
     config_yaml = yaml.dump(train_params_dict, indent=4)
-    file_name = os.path.join(save_dir,'train_params.yaml')
+    file_name = os.path.join(save_dir, "train_params.yaml")
 
     # Check if there's a directory with the same name as the file
     if os.path.isdir(file_name):
         print(f"Error: {file_name} is a directory, not a file.")
     else:
         # Write the YAML string to the file
-        with open(file_name, 'w') as f:
+        with open(file_name, "w") as f:
             f.write(config_yaml)
-        if rank==0:
+        if rank == 0:
             print(f"training params are saved in {file_name}")
 
-def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+
+def save_to_json(
+    output_filename,
+    train_step_loss,
+    train_epoch_loss,
+    train_step_ppl,
+    train_epoch_ppl,
+    val_step_loss,
+    val_epoch_loss,
+    val_step_ppl,
+    val_epoch_ppl,
+):
     metrics_data = {
         "train_step_loss": train_step_loss,
         "train_epoch_loss": train_epoch_loss,
@@ -561,7 +593,7 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_step_loss": val_step_loss,
         "val_epoch_loss": val_epoch_loss,
         "val_step_perplexity": val_step_ppl,
-        "val_epoch_perplexity": val_epoch_ppl
+        "val_epoch_perplexity": val_epoch_ppl,
    }
    with open(output_filename, "w") as f:
        json.dump(metrics_data, f)
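
The single save_checkpoint(model, optimizer, train_config, fsdp_config, epoch) call introduced in train() above replaces the per-format checkpoint branches deleted by this patch. Below is a minimal sketch of what such a consolidated helper could look like, reassembled only from the deleted branches (PEFT adapters, plain single-device save, FSDP FULL_STATE_DICT, FSDP SHARDED_STATE_DICT). It is an assumption for illustration, not the actual implementation in llama_recipes.model_checkpointing, and it reuses the old helper functions simply because the deleted code called them with these signatures.

# Hypothetical sketch only: assumes the old per-format helpers remain importable and that
# the rank can be recovered from torch.distributed when FSDP is enabled.
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType

from llama_recipes.model_checkpointing import (
    save_fsdp_model_checkpoint_full,
    save_model_and_optimizer_sharded,
    save_model_checkpoint,
    save_optimizer_checkpoint,
    save_peft_checkpoint,
)


def save_checkpoint(model, optimizer, train_config, fsdp_config, epoch):
    rank = dist.get_rank() if train_config.enable_fsdp else 0
    if train_config.enable_fsdp:
        dist.barrier()  # enter checkpointing together on all ranks
    if train_config.use_peft:
        # PEFT run: only the adapter weights are written to output_dir.
        save_peft_checkpoint(model, train_config.output_dir)
        if not train_config.enable_fsdp or rank == 0:
            print(f"PEFT modules are saved in {train_config.output_dir} directory")
    elif not train_config.enable_fsdp:
        # Single-device full fine-tuning: plain model checkpoint.
        save_model_checkpoint(model, train_config.output_dir)
    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
        # FSDP with a consolidated full state dict, optionally including the optimizer.
        save_fsdp_model_checkpoint_full(model, optimizer, rank, train_config, epoch=epoch)
        if train_config.save_optimizer:
            save_optimizer_checkpoint(model, optimizer, rank, train_config, epoch=epoch)
    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
        # FSDP with sharded checkpoints; the optimizer can be saved in the same call.
        if train_config.save_optimizer:
            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
        else:
            save_model_and_optimizer_sharded(model, rank, train_config)
    if train_config.enable_fsdp:
        dist.barrier()  # keep ranks in sync before training resumes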