Browse source

changed context name and add more docs

Kai Wu 11 months ago
parent
commit
d9558c11ca

File diff not shown because the file is too large
+ 11 - 2
docs/multi_gpu.md


File diff not shown because the file is too large
+ 11 - 2
docs/single_gpu.md


File diff not shown because the file is too large
+ 1 - 1
recipes/finetuning/README.md


File diff not shown because the file is too large
+ 6 - 3
recipes/finetuning/multigpu_finetuning.md


File diff not shown because the file is too large
+ 8 - 2
recipes/finetuning/singlegpu_finetuning.md


+ 1 - 0
requirements.txt

@@ -18,3 +18,4 @@ gradio
 chardet
 openai
 typing-extensions==4.8.0
+tabulate
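
The only dependency change is the new tabulate entry. Its call site is not shown in this diff; presumably it is used to pretty-print tabular summaries in the updated docs or training output. A minimal sketch of the library's basic API for reference — the metric names below are made up, not taken from the repository:

from tabulate import tabulate

# Illustrative rows only; tabulate renders a list of rows as an aligned text table.
rows = [
    ["train/loss", 1.23],
    ["train/perplexity", 3.42],
]
print(tabulate(rows, headers=["metric", "value"], tablefmt="grid"))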

+ 8 - 9
src/llama_recipes/utils/train_utils.py

@@ -31,7 +31,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.padding_side = "left"
 
 @contextlib.contextmanager
-def throughput_measure_context(cfg, local_rank=None):
+def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
     use_flop_counter: bool = cfg.flop_counter
     if use_flop_counter and use_profiler:
@@ -41,7 +41,7 @@ def throughput_measure_context(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
         print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
@@ -97,7 +97,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-
     train_prep = []
     train_loss = []
     val_prep = []
@@ -127,7 +126,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with throughput_measure_context(train_config,local_rank) as measure_context:
+            with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
@@ -138,7 +137,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         break
                     if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
                         print("start flop counting at the step: ", total_train_steps)
-                        measure_context.start_counting()
+                        profile_context.start_counting()
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
@@ -185,10 +184,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             optimizer.zero_grad()
                             pbar.update(1)
                     if train_config.use_profiler:
-                        measure_context.step()
-                    if train_config.flop_counter and measure_context.is_ready():
-                        TFlops = measure_context.get_total_flops() / 1e12
-                        measure_context.stop_counting()
+                        profile_context.step()
+                    if train_config.flop_counter and profile_context.is_ready():
+                        TFlops = profile_context.get_total_flops() / 1e12
+                        profile_context.stop_counting()
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
                             wandb_run.log({
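
The remaining hunks rename throughput_measure_context to profile (and measure_context to profile_context) and expand the ValueError message, so the training behaviour itself is unchanged. The step budget the new message spells out is wait + warmup + active + 1 = 1 + 2 + 3 + 1 = 7, so max_train_step must be at least 7 for the profiler to complete one recording cycle. A minimal, self-contained sketch of the torch.profiler schedule that the profile context manager wraps; the trace directory and the dummy workload are illustrative assumptions, not the repository's code:

import torch

# Step budget from the diff: 1 wait + 2 warmup + 3 active steps, plus one extra step.
wait_step, warmup_step, active_step = 1, 2, 3
min_step = wait_step + warmup_step + active_step + 1  # = 7

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(
    activities=activities,
    schedule=torch.profiler.schedule(
        wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
    # Illustrative output directory; train_utils.py writes to cfg.profiler_dir instead.
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_out"),
) as prof:
    for _ in range(min_step):
        torch.randn(256, 256) @ torch.randn(256, 256)  # dummy work standing in for one train step
        prof.step()  # corresponds to profile_context.step() in the training loop above

When cfg.flop_counter is set instead of cfg.use_profiler, the same context manager exposes the start_counting(), is_ready(), get_total_flops() and stop_counting() calls used in the loop above to report TFlops.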