
changed context manager name and added more docs

Kai Wu committed 11 months ago
parent
commit
d9558c11ca

The diff for this file was suppressed because it is too large
+ 11 - 2
docs/multi_gpu.md


The diff for this file was suppressed because it is too large
+ 11 - 2
docs/single_gpu.md


The diff for this file was suppressed because it is too large
+ 1 - 1
recipes/finetuning/README.md


The diff for this file was suppressed because it is too large
+ 6 - 3
recipes/finetuning/multigpu_finetuning.md


The diff for this file was suppressed because it is too large
+ 8 - 2
recipes/finetuning/singlegpu_finetuning.md


+ 1 - 0
requirements.txt

@@ -18,3 +18,4 @@ gradio
 chardet
 openai
 typing-extensions==4.8.0
+tabulate

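The commit adds `tabulate` as a dependency; the diff does not show where it is used, but it is presumably for printing summary tables in the updated docs or utilities. A minimal, purely illustrative sketch of that kind of usage (metric names and values are made up):

```python
# Illustrative only: print a small metrics table with the newly added tabulate dependency.
from tabulate import tabulate

rows = [["train_loss", 1.23], ["TFlops", 312.4]]  # hypothetical values
print(tabulate(rows, headers=["metric", "value"], tablefmt="grid"))
```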
+ 8 - 9
src/llama_recipes/utils/train_utils.py

@@ -31,7 +31,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.padding_side = "left"
 
 @contextlib.contextmanager
-def throughput_measure_context(cfg, local_rank=None):
+def profile(cfg, local_rank=None):
     use_profiler: bool = cfg.use_profiler
     use_flop_counter: bool = cfg.flop_counter
     if use_flop_counter and use_profiler:
@@ -41,7 +41,7 @@ def throughput_measure_context(cfg, local_rank=None):
         wait_step, warmup_step, active_step = 1, 2, 3
         min_step = wait_step + warmup_step + active_step + 1
         if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
-            raise ValueError(f"pytorch profiler requires at least {min_step} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
         print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
@@ -97,7 +97,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-
     train_prep = []
     train_loss = []
     val_prep = []
@@ -127,7 +126,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            with throughput_measure_context(train_config,local_rank) as measure_context:
+            with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
                     # stop when the maximum number of training steps is reached
@@ -138,7 +137,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         break
                     if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
                         print("start flop counting at the step: ", total_train_steps)
-                        measure_context.start_counting()
+                        profile_context.start_counting()
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
@@ -185,10 +184,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             optimizer.zero_grad()
                             pbar.update(1)
                     if train_config.use_profiler:
-                        measure_context.step()
-                    if train_config.flop_counter and measure_context.is_ready():
-                        TFlops = measure_context.get_total_flops() / 1e12
-                        measure_context.stop_counting()
+                        profile_context.step()
+                    if train_config.flop_counter and profile_context.is_ready():
+                        TFlops = profile_context.get_total_flops() / 1e12
+                        profile_context.stop_counting()
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
                             wandb_run.log({
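For reference, the calling pattern introduced by the rename can be exercised outside the full training loop. This is a minimal sketch, not the repo's training entry point: the `ToyConfig` fields mirror the attributes visibly read in the diff (`use_profiler`, `flop_counter`, `flop_counter_start`, `profiler_dir`, `max_train_step`); the real config may require more fields, and the model, optimizer, and step count are purely illustrative.

```python
# Sketch of driving the renamed profile() context manager from
# src/llama_recipes/utils/train_utils.py with a toy model and config.
from dataclasses import dataclass

import torch
from llama_recipes.utils.train_utils import profile  # assumes llama_recipes is installed

@dataclass
class ToyConfig:
    use_profiler: bool = True       # torch.profiler path; mutually exclusive with flop_counter
    flop_counter: bool = False
    flop_counter_start: int = 3     # step at which FLOP counting would begin
    profiler_dir: str = "profiler_out"
    max_train_step: int = 0         # 0 = no cap; with the profiler on, a cap must be >= 7 steps

cfg = ToyConfig()
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

with profile(cfg) as profile_context:
    for step in range(8):  # enough steps to cover wait(1) + warmup(2) + active(3) + 1
        loss = model(torch.randn(4, 16)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if cfg.flop_counter and step == cfg.flop_counter_start:
            profile_context.start_counting()
        if cfg.use_profiler:
            profile_context.step()  # advance the torch.profiler schedule
        if cfg.flop_counter and profile_context.is_ready():
            print(profile_context.get_total_flops() / 1e12, "TFlops")
            profile_context.stop_counting()
```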