
fixed some typos to pass spellcheck

Kai Wu 1 year ago
Parent
Commit
03f1ca7817
3 changed files with 11 additions and 11 deletions
  1. recipes/finetuning/README.md (+6 -6)
  2. src/llama_recipes/configs/training.py (+2 -2)
  3. src/llama_recipes/utils/train_utils.py (+3 -3)

+ 6 - 6
recipes/finetuning/README.md

@@ -50,9 +50,9 @@ save_model: bool = False
 dist_checkpoint_root_folder: str="model_checkpoints"
 dist_checkpoint_folder: str="fine-tuned"
 save_optimizer: bool=False
-flop_counter: bool=False # Enable Flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_startpoint: int=3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with flop counter at the same time.
+flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
+flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
+use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
 profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
 ```
 
@@ -94,8 +94,8 @@ You'll be able to access a dedicated project or run link on [wandb.ai](https://w
     <img src="../../docs/images/wandb_screenshot.png" alt="wandb screenshot" width="500" />
 </div>
 
-## FLop Counting and Pytorch Profiling
+## FLOPS Counting and Pytorch Profiling
 
-To help with benchmarking effort, we are adding the support for counting the flops during the fine-tuning process. You can achieve this by setting `--flop_counter` when launching your single/multi GPU fine-tuning. Use `--flop_counter_startpoint` to choose which step to count the flops. It is recommended to allow a warmup stage before using the flop counter.
+To help with benchmarking effort, we are adding the support for counting the FLOPS during the fine-tuning process. You can achieve this by setting `--flop_counter` when launching your single/multi GPU fine-tuning. Use `--flop_counter_start` to choose which step to count the FLOPS. It is recommended to allow a warm-up stage before using the FLOPS counter.
 
-Similarly, you can set `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). This would be helpful for debugging purposes. However, the `--flop_counter` and `--use_profiler` can not be used in the same time to ensure the measurement accuarcy.
+Similarly, you can set `--use_profiler` flag and pass a profiling output path using `--profiler_dir` to capture the profile traces of your model using [PyTorch profiler](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). This would be helpful for debugging purposes. However, the `--flop_counter` and `--use_profiler` can not be used in the same time to ensure the measurement accuracy.
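The renamed flag maps directly onto the `train_config` field updated in the next file. As a rough sketch (assuming the dataclass from `src/llama_recipes/configs/training.py` is instantiated directly rather than driven by the CLI flags above), enabling the FLOPS counter programmatically could look like this:

```python
# Sketch only: field names come from the train_config diff in this commit;
# setting them from Python instead of the CLI flags (--flop_counter,
# --flop_counter_start) is an assumption made for illustration.
from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.flop_counter = True       # measure model throughput in FLOPS
cfg.flop_counter_start = 3    # start counting after a 3-step warm-up
cfg.use_profiler = False      # mutually exclusive with the FLOPS counter
```

Leaving `use_profiler` off matters here because, as the README text above notes, the profiler and the FLOPS counter cannot be used at the same time.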

+ 2 - 2
src/llama_recipes/configs/training.py

@@ -42,7 +42,7 @@ class train_config:
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_wandb: bool = False # Enable wandb for experient tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting
-    flop_counter: bool = False # Enable Flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
-    flop_counter_startpoint: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
+    flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
+    flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
     use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
     profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

+ 3 - 3
src/llama_recipes/utils/train_utils.py

@@ -59,8 +59,8 @@ def throughput_measure_context(cfg, local_rank=None):
         ) as torch_profiler:
             yield torch_profiler
     elif use_flop_counter:
-        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_startpoint:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_startpoint} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
         with FlopMeasure(rank=local_rank) as flop_counter:
             yield flop_counter
     else:
@@ -136,7 +136,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if not train_config.enable_fsdp or local_rank==0:
                             print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
                         break
-                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
+                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
                         print("start flop counting at the step: ", total_train_steps)
                         measure_context.start_counting()
                     for key in batch.keys():
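For readers without the full file, the hunks above only show fragments of `throughput_measure_context`. Below is a simplified reconstruction of the dispatch those fragments imply, after the rename to `flop_counter_start`; the profiler arguments, the `FlopMeasure` import path, and the no-op fallback are assumptions rather than the repo's exact code.

```python
import contextlib

import torch

# Assumption: FlopMeasure is the recipes' FLOPS-counting helper; the exact
# import path may differ in the repo.
from llama_recipes.utils.flop_utils import FlopMeasure


@contextlib.contextmanager
def throughput_measure_context(cfg, local_rank=None):
    """Yield a PyTorch profiler, a FLOPS counter, or a no-op context based on cfg."""
    if cfg.use_profiler:
        # Capture profiler traces and write them under cfg.profiler_dir.
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=torch.profiler.tensorboard_trace_handler(cfg.profiler_dir),
        ) as torch_profiler:
            yield torch_profiler
    elif cfg.flop_counter:
        # Refuse configs that would end training before the warm-up window finishes.
        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start:
            raise ValueError(
                f"flop counter requires at least {cfg.flop_counter_start} train steps, "
                f"current max_train_step is {cfg.max_train_step}"
            )
        with FlopMeasure(rank=local_rank) as flop_counter:
            yield flop_counter
    else:
        # Neither measurement tool requested: hand back a do-nothing context.
        with contextlib.nullcontext() as noop:
            yield noop
```

Routing both tools through one context manager gives a single dispatch point, which is what keeps `--flop_counter` and `--use_profiler` mutually exclusive as the README and config comments require.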
                     for key in batch.keys():