@@ -25,7 +25,6 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
-#from llama_recipes.utils.tflop_counter import FlopCounterMode
 from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -68,12 +67,6 @@ def throughput_measure_context(cfg, local_rank=None):
         torch_profiler = contextlib.nullcontext()
         yield None
 
-def get_total_flops(model):
-    return (sum([v for _, v in model.flop_counts["Global"].items()]))
-
-# Converting Bytes to Megabytes
-def byte2mb(x):
-    return int(x / 2**20)
 
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
@@ -194,7 +187,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.use_profiler:
                         measure_context.step()
                     if train_config.flop_counter and measure_context.is_ready():
-                        TFlops = get_total_flops(measure_context) / 1e12
+                        TFlops = measure_context.get_total_flops() / 1e12
                         measure_context.stop_counting()
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
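
For reference, the helper deleted in the second hunk summed the per-operator counts that torch's FlopCounterMode accumulates under the "Global" key; the third hunk shows that this aggregation now lives on the measuring object itself. Below is a minimal, self-contained sketch of what FlopMeasure.get_total_flops() plausibly does — the subclass is illustrative, not the actual llama_recipes.utils.flop_utils implementation.

# Sketch only: mirrors the deleted get_total_flops() helper as a method.
# The real FlopMeasure lives in llama_recipes/utils/flop_utils.py.
import torch
from torch.utils.flop_counter import FlopCounterMode

class FlopMeasure(FlopCounterMode):
    def get_total_flops(self):
        # Same arithmetic as the removed module-level helper: sum every
        # per-operator count recorded under the "Global" key.
        return sum(self.flop_counts["Global"].values())

model = torch.nn.Linear(512, 512)
with FlopMeasure(display=False) as measure:
    model(torch.randn(8, 512)).sum().backward()
print(f"TFlops: {measure.get_total_flops() / 1e12:.6f}")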
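
The third hunk also fixes the calling convention: step() advances the measuring context each iteration, is_ready() gates the read, and stop_counting() ends measurement. A schematic of that sequence follows — the flag and method names come from the diff, while the import path and loop scaffolding are assumptions about the surrounding train_utils.py.

# Schematic only: flag names and measure_context calls are from the diff;
# the import path and loop body are assumed for illustration.
from llama_recipes.utils.train_utils import throughput_measure_context

def run_epoch(model, train_dataloader, optimizer, train_config, local_rank=None):
    TFlops = None
    with throughput_measure_context(train_config, local_rank) as measure_context:
        for step, batch in enumerate(train_dataloader):
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if train_config.use_profiler:
                measure_context.step()
            if train_config.flop_counter and measure_context.is_ready():
                # Renamed call: totals now come from the measure object
                # instead of the deleted module-level helper.
                TFlops = measure_context.get_total_flops() / 1e12
                measure_context.stop_counting()
    return TFlops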