Explorar o código

changed readme, unified the context interface and added get_flops_per_sec()

Kai Wu hai 11 meses
pai
achega
26e877fd42

+ 41 - 29
docs/multi_gpu.md

@@ -115,35 +115,47 @@ torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --m
 It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings:
 
 ```python
-
-model_name: str="PATH/to/LLAMA 2/7B"
-enable_fsdp: bool= False
-run_validation: bool=True
-batch_size_training: int=4
-gradient_accumulation_steps: int=1
-num_epochs: int=3
-num_workers_dataloader: int=2
-lr: float=2e-4
-weight_decay: float=0.0
-gamma: float= 0.85
-use_fp16: bool=False
-mixed_precision: bool=True
-val_batch_size: int=4
-dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
-peft_method: str = "lora" # None , llama_adapter, prefix
-use_peft: bool=False
-output_dir: str = "./ft-output"
-freeze_layers: bool = False
-num_freeze_layers: int = 1
-quantization: bool = False
-save_model: bool = False
-dist_checkpoint_root_folder: str="model_checkpoints"
-dist_checkpoint_folder: str="fine-tuned"
-save_optimizer: bool=False
-flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None,llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experient tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
+    flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
+    use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
 ```
 
 * [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

+ 41 - 30
docs/single_gpu.md

@@ -71,36 +71,47 @@ python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization
 It let us specify the training settings, everything from `model_name` to `dataset_name`, `batch_size` etc. can be set here. Below is the list of supported settings:
 
 ```python
-
-model_name: str="PATH/to/LLAMA 2/7B"
-enable_fsdp: bool= False
-run_validation: bool=True
-batch_size_training: int=4
-gradient_accumulation_steps: int=1
-num_epochs: int=3
-num_workers_dataloader: int=2
-lr: float=2e-4
-weight_decay: float=0.0
-gamma: float= 0.85
-use_fp16: bool=False
-mixed_precision: bool=True
-val_batch_size: int=4
-dataset = "samsum_dataset" # alpaca_dataset,grammar_dataset
-peft_method: str = "lora" # None , llama_adapter, prefix
-use_peft: bool=False
-output_dir: str = "./ft-output"
-freeze_layers: bool = False
-num_freeze_layers: int = 1
-quantization: bool = False
-one_gpu: bool = False
-save_model: bool = False
-dist_checkpoint_root_folder: str="model_checkpoints"
-dist_checkpoint_folder: str="fine-tuned"
-save_optimizer: bool=False
-flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None,llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experient tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
+    flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
+    use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
 ```
 
 * [Datasets config file](../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

+ 41 - 31
recipes/finetuning/README.md

@@ -23,37 +23,47 @@ If you are new to fine-tuning techniques, check out an overview: [](./LLM_finetu
 It lets us specify the training settings for everything from `model_name` to `dataset_name`, `batch_size` and so on. Below is the list of supported settings:
 
 ```python
-
-model_name: str="PATH/to/LLAMA 2/7B"
-enable_fsdp: bool=False
-run_validation: bool=True
-batch_size_training: int=4
-gradient_accumulation_steps: int=1
-max_train_step: int=0
-max_eval_step: int=0
-num_epochs: int=3
-num_workers_dataloader: int=2
-lr: float=2e-4
-weight_decay: float=0.0
-gamma: float=0.85
-use_fp16: bool=False
-mixed_precision: bool=True
-val_batch_size: int=4
-dataset = "samsum_dataset" # alpaca_dataset, grammar_dataset
-peft_method: str="lora" # None , llama_adapter, prefix
-use_peft: bool=False
-output_dir: str="./ft-output"
-freeze_layers: bool = False
-num_freeze_layers: int = 1
-quantization: bool = False
-save_model: bool = False
-dist_checkpoint_root_folder: str="model_checkpoints"
-dist_checkpoint_folder: str="fine-tuned"
-save_optimizer: bool=False
-flop_counter: bool=False # Enable FLOPS counter to measure model throughput, can not be used with pytorch profiler at the same time.
-flop_counter_start: int=3 # The step to start profiling, default is 3, which means after 3 steps of warm-up stage, the profiler will start to count FLOPS.
-use_profiler: bool=False # Enable pytorch profiler, can not be used with FLOPS counter at the same time.
-profiler_dir: str="PATH/to/save/profiler/results" # will be used if using profiler
+    model_name: str="PATH/to/Model"
+    tokenizer_name: str=None
+    enable_fsdp: bool=False
+    low_cpu_fsdp: bool=False
+    run_validation: bool=True
+    batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
+    gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
+    num_epochs: int=3
+    max_train_step: int=0
+    max_eval_step: int=0
+    num_workers_dataloader: int=1
+    lr: float=1e-4
+    weight_decay: float=0.0
+    gamma: float= 0.85
+    seed: int=42
+    use_fp16: bool=False
+    mixed_precision: bool=True
+    val_batch_size: int=1
+    dataset = "samsum_dataset"
+    peft_method: str = "lora" # None,llama_adapter, prefix
+    use_peft: bool=False
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
+    save_optimizer: bool=False # will be used if using FSDP
+    use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experient tracking
+    save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
+    flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops.
+    use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time.
+    profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler
 ```
 
 * [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

+ 2 - 2
src/llama_recipes/configs/training.py

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 
 @dataclass
 class train_config:
-    model_name: str="PATH/to/LLAMA/7B"
+    model_name: str="PATH/to/Model"
     tokenizer_name: str=None
     enable_fsdp: bool=False
     low_cpu_fsdp: bool=False
@@ -29,7 +29,7 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None , llama_adapter, prefix
+    peft_method: str = "lora" # None,llama_adapter, prefix
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False

+ 34 - 28
src/llama_recipes/utils/flop_utils.py

@@ -1,5 +1,5 @@
 from typing import Any, Dict, List, Optional, Union
-
+import time
 import torch
 from torch.utils.flop_counter import FlopCounterMode
 
@@ -15,14 +15,12 @@ class FlopMeasure(FlopCounterMode):
 
     .. code-block:: python
 
-        mod = ...
-        flop_counter = FlopMeasure(mod)
+        model = ...
+        flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3)
         for batch in enumerate(dataloader):
             with flop_counter:
-                if step == 3:
-                    flop_counter.start_counting()
-                mod(batch)
-                flop_counter.stop_counting()
+                model(batch)
+                flop_counter.step()
     """
 
     def __init__(
@@ -32,50 +30,58 @@ class FlopMeasure(FlopCounterMode):
         display: bool = True,
         custom_mapping: Dict[Any, Any] = None,
         rank=None,
+        warmup_step: int = 3,
     ):
         super().__init__(mods, depth, display, custom_mapping)
-        self.ready = False
         self.rank = rank
+        self.warmup_step = warmup_step
+        self.start_time = 0
+        self.end_time = 0
 
+    def step(self):
+        # decrease the warmup step by 1 for every step, so that the flop counting will start when warmup_step =0. Stop decreasing when warm_up reaches -1.
+        if self.warmup_step >= 0:
+            self.warmup_step -= 1
+        if self.warmup_step == 0 and self.start_time == 0:
+            self.start_time = time.time()
+        elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
+            self.end_time = time.time()
     def __enter__(self):
-        self.ready = False
+        if self.warmup_step == 0:
+            self.start_time = time.time()
         super().__enter__()
         return self
-
+    def is_done(self):
+        return self.warmup_step == -1
     def get_total_flops(self):
         return super().get_total_flops()
-
+    def get_flops_per_sec(self):
+        if self.start_time == 0 or self.end_time == 0:
+            print("Warning: flop count did not finish correctly")
+            return 0
+        return super().get_total_flops()/ (self.end_time - self.start_time)
     def get_table(self, depth=2):
         return super().get_table(depth)
 
     def __exit__(self, *args):
-        self.ready = False
         if self.get_total_flops() == 0:
             print(
                 "Warning: did not record any flops this time. Skipping the flop report"
             )
         else:
-            self.stop_counting()
             if self.display:
                 if self.rank is None or self.rank == 0:
-                    print("self.flop_counts", self.get_total_flops())
+                    print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
+                    print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
+                    print("The tflop_count table is below:")
                     print(self.get_table(self.depth))
             # Disable the display feature so that we don't print the table again
             self.display = False
         super().__exit__(*args)
 
-    def start_counting(self):
-        self.ready = True
-
-    def is_ready(self):
-        return self.ready
-
-    def stop_counting(self):
-        self.ready = False
-
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        # return the original output if not ready
-        if not self.ready:
-            return func(*args, **kwargs)
-        # otherwise, count the flops and return the original output
-        return super().__torch_dispatch__(func, types, args, kwargs)
+        # when warmup_step is 0, count the flops and return the original output
+        if self.warmup_step == 0:
+            return super().__torch_dispatch__(func, types, args, kwargs)
+        # otherwise, just return the original output
+        return func(*args, **kwargs)

+ 6 - 10
src/llama_recipes/utils/train_utils.py

@@ -59,9 +59,9 @@ def profile(cfg, local_rank=None):
         ) as torch_profiler:
             yield torch_profiler
     elif use_flop_counter:
-        if cfg.max_train_step > 0 and cfg.max_train_step < cfg.flop_counter_start:
-            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
-        with FlopMeasure(rank=local_rank) as flop_counter:
+        if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+        with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
             yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
@@ -135,9 +135,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if not train_config.enable_fsdp or local_rank==0:
                             print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
                         break
-                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_start:
-                        print("start flop counting at the step: ", total_train_steps)
-                        profile_context.start_counting()
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
@@ -183,11 +180,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
-                    if train_config.use_profiler:
+                    if train_config.use_profiler or train_config.flop_counter:
                         profile_context.step()
-                    if train_config.flop_counter and profile_context.is_ready():
-                        TFlops = profile_context.get_total_flops() / 1e12
-                        profile_context.stop_counting()
+                    if train_config.flop_counter and profile_context.is_done():
+                        TFlops = profile_context.get_flops_per_sec() / 1e12
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
                             wandb_run.log({