
Added a feature that allows users to use the PyTorch profiler or flop_counter to measure performance during fine-tuning. (#433)

Kai Wu, 11 months ago
commit 0ab53c269d

File diff suppressed because it is too large
+ 48 - 27
docs/multi_gpu.md


File diff suppressed because it is too large
+ 48 - 28
docs/single_gpu.md


File diff suppressed because it is too large
+ 50 - 29
recipes/finetuning/README.md


File diff suppressed because it is too large
+ 6 - 3
recipes/finetuning/multigpu_finetuning.md


File diff suppressed because it is too large
+ 8 - 2
recipes/finetuning/singlegpu_finetuning.md


+ 1 - 0
requirements.txt

@@ -18,3 +18,4 @@ gradio
 chardet
 openai
 typing-extensions==4.8.0
+tabulate

+ 6 - 2
src/llama_recipes/configs/training.py

@@ -6,7 +6,7 @@ from dataclasses import dataclass
 
 @dataclass
 class train_config:
-    model_name: str="PATH/to/LLAMA/7B"
+    model_name: str="PATH/to/Model"
     tokenizer_name: str=None
     enable_fsdp: bool=False
     low_cpu_fsdp: bool=False
@@ -29,7 +29,7 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None , llama_adapter, prefix
+    peft_method: str = "lora" # None,llama_adapter, prefix
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
@@ -43,3 +43,7 @@ class train_config:
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     use_wandb: bool = False # Enable wandb for experiment tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting
+    flop_counter: bool = False # Enable the flop counter to measure model throughput; cannot be used together with the PyTorch profiler.
+    flop_counter_start: int = 3 # The step at which to start counting flops; counting begins after this many warmup steps (default 3).
+    use_profiler: bool = False # Enable the PyTorch profiler; cannot be used together with the flop counter.
+    profiler_dir: str = "PATH/to/save/profiler/results" # used only when the profiler is enabled
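
As a quick orientation for the new options, the sketch below shows how they might be set on the training config. This is a minimal, hypothetical example: it assumes train_config is importable from llama_recipes.configs as elsewhere in this repo, and the values are illustrative only.

# Hypothetical sketch: enabling the new throughput-measurement options.
from llama_recipes.configs import train_config as TRAIN_CONFIG

cfg = TRAIN_CONFIG()
cfg.flop_counter = True        # count model flops and report TFLOP/s
cfg.flop_counter_start = 3     # start counting after a 3-step warmup
# The PyTorch profiler is mutually exclusive with the flop counter:
# cfg.use_profiler = True
# cfg.profiler_dir = "PATH/to/save/profiler/results"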

+ 87 - 0
src/llama_recipes/utils/flop_utils.py

@@ -0,0 +1,87 @@
+from typing import Any, Dict, List, Optional, Union
+import time
+import torch
+from torch.utils.flop_counter import FlopCounterMode
+
+
+class FlopMeasure(FlopCounterMode):
+    """
+    ``FlopMeasure`` is a customized context manager that counts the number of
+    flops within its context. It is based on ``FlopCounterMode``, with an additional ``step()`` method so that flop counting
+    only starts after the warmup stage.
+    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
+
+    Example usage
+
+    .. code-block:: python
+
+        model = ...
+        flop_counter = FlopMeasure(model, rank=0, warmup_step=3)
+        with flop_counter:
+            for batch in dataloader:
+                model(batch)
+                flop_counter.step()
+    """
+
+    def __init__(
+        self,
+        mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
+        depth: int = 2,
+        display: bool = True,
+        custom_mapping: Dict[Any, Any] = None,
+        rank=None,
+        warmup_step: int = 3,
+    ):
+        super().__init__(mods, depth, display, custom_mapping)
+        self.rank = rank
+        self.warmup_step = warmup_step
+        self.start_time = 0
+        self.end_time = 0
+
+    def step(self):
+        # decrement warmup_step by 1 on every step so that flop counting starts once warmup_step reaches 0; stop decrementing once it reaches -1.
+        if self.warmup_step >= 0:
+            self.warmup_step -= 1
+        if self.warmup_step == 0 and self.start_time == 0:
+            self.start_time = time.time()
+        elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
+            self.end_time = time.time()
+    def __enter__(self):
+        if self.warmup_step == 0:
+            self.start_time = time.time()
+        super().__enter__()
+        return self
+    def is_done(self):
+        return self.warmup_step == -1
+    def get_total_flops(self):
+        return super().get_total_flops()
+    def get_flops_per_sec(self):
+        if self.start_time == 0 or self.end_time == 0:
+            print("Warning: flop count did not finish correctly")
+            return 0
+        return super().get_total_flops()/ (self.end_time - self.start_time)
+    def get_table(self, depth=2):
+        return super().get_table(depth)
+
+    def __exit__(self, *args):
+        if self.get_total_flops() == 0:
+            print(
+                "Warning: did not record any flops this time. Skipping the flop report"
+            )
+        else:
+            if self.display:
+                if self.rank is None or self.rank == 0:
+                    print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
+                    print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
+                    print("The tflop_count table is below:")
+                    print(self.get_table(self.depth))
+            # Disable the display feature so that we don't print the table again
+            self.display = False
+        super().__exit__(*args)
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        # when warmup_step is 0, count the flops and return the original output
+        if self.warmup_step == 0:
+            return super().__torch_dispatch__(func, types, args, kwargs)
+        # otherwise, just return the original output
+        return func(*args, **kwargs)
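
The class above can also be exercised on its own. The following is a minimal, standalone sketch (the linear model and random batches are hypothetical stand-ins); it follows the same wrap-once pattern the training loop uses below, calling step() once per batch. Printing the summary table relies on the tabulate package added to requirements.txt in this commit.

# Standalone sketch of FlopMeasure with a toy model (hypothetical example).
import torch
from llama_recipes.utils.flop_utils import FlopMeasure

model = torch.nn.Linear(1024, 1024)
batches = [torch.randn(8, 1024) for _ in range(4)]

with FlopMeasure(rank=0, warmup_step=3) as flop_counter:
    for batch in batches:
        model(batch).sum().backward()
        flop_counter.step()  # counting starts once warmup_step reaches 0
        if flop_counter.is_done():
            # flops were recorded for exactly one post-warmup step
            print(f"{flop_counter.get_flops_per_sec() / 1e12:.6f} TFLOP/s")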

+ 109 - 71
src/llama_recipes/utils/train_utils.py

@@ -8,6 +8,7 @@ from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 from datetime import datetime
+import contextlib
 
 
 import torch
@@ -24,14 +25,48 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
-
+from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
 
-# Converting Bytes to Megabytes
-def byte2mb(x):
-    return int(x / 2**20)
+@contextlib.contextmanager
+def profile(cfg, local_rank=None):
+    use_profiler: bool = cfg.use_profiler
+    use_flop_counter: bool = cfg.flop_counter
+    if use_flop_counter and use_profiler:
+        raise ValueError("Cannot use both profiler and flop counter")
+    if use_profiler:
+        # profiler needs a warmup stage to get the accurate profiling results
+        wait_step, warmup_step, active_step = 1, 2, 3
+        min_step = wait_step + warmup_step + active_step + 1
+        if cfg.max_train_step > 0 and cfg.max_train_step < min_step:
+            raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+        print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}")
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1),
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                cfg.profiler_dir
+            ),
+            profile_memory=True,
+            with_stack=False,
+            with_flops=True,
+            record_shapes=True,
+        ) as torch_profiler:
+            yield torch_profiler
+    elif use_flop_counter:
+        if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start:
+            raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}")
+        with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter:
+            yield flop_counter
+    else:
+        torch_profiler = contextlib.nullcontext()
+        yield None
+
 
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
@@ -62,7 +97,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
 
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-
     train_prep = []
     train_loss = []
     val_prep = []
@@ -92,73 +126,77 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
-            for step, batch in enumerate(train_dataloader):
-                total_train_steps += 1
-                # stop when the maximum number of training steps is reached
-                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
-                    max_steps_reached = True
-                    if not train_config.enable_fsdp or local_rank==0:
-                        print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
-                    break
-                for key in batch.keys():
-                    if train_config.enable_fsdp:
-                        if is_xpu_available():
-                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+            with profile(train_config,local_rank) as profile_context:
+                for step, batch in enumerate(train_dataloader):
+                    total_train_steps += 1
+                    # stop when the maximum number of training steps is reached
+                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
+                        max_steps_reached = True
+                        if not train_config.enable_fsdp or local_rank==0:
+                            print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1)
+                        break
+                    for key in batch.keys():
+                        if train_config.enable_fsdp:
+                            if is_xpu_available():
+                                batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                            else:
+                                batch[key] = batch[key].to(local_rank)
                         else:
-                            batch[key] = batch[key].to(local_rank)
-                    else:
 
-                        if is_xpu_available():
-                            batch[key] = batch[key].to('xpu:0')
-                        else:
-                            batch[key] = batch[key].to('cuda:0')
-                with autocast():
-                    loss = model(**batch).loss
-                loss = loss / gradient_accumulation_steps
-                if train_config.save_metrics:
-                    train_step_loss.append(loss.detach().float().item())
-                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
-                total_loss += loss.detach().float()
-                if train_config.use_fp16:
-                    # if fp16 is enabled, use gradient scaler to handle gradient update
-                    scaler.scale(loss).backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            scaler.unscale_(optimizer)
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            if is_xpu_available():
+                                batch[key] = batch[key].to('xpu:0')
                             else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        scaler.step(optimizer)
-                        scaler.update()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-                else:
-                    # regular backpropagation when fp16 is not used
-                    loss.backward()
-                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
-                            if train_config.enable_fsdp:
-                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
-                            else:
-                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
-                        optimizer.step()
-                        optimizer.zero_grad()
-                        pbar.update(1)
-
-                if wandb_run:
-                    if not train_config.enable_fsdp or rank==0:
-                        wandb_run.log({
-                            'train/epoch': epoch + 1,
-                            'train/step': epoch * len(train_dataloader) + step,
-                            'train/loss': loss.detach().float(),
-                        })
-
-                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
-
-                if train_config.save_metrics:
-                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
-            pbar.close()
+                                batch[key] = batch[key].to('cuda:0')
+                    with autocast():
+                        loss = model(**batch).loss
+                    loss = loss / gradient_accumulation_steps
+                    if train_config.save_metrics:
+                        train_step_loss.append(loss.detach().float().item())
+                        train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                    total_loss += loss.detach().float()
+                    if train_config.use_fp16:
+                        # if fp16 is enabled, use gradient scaler to handle gradient update
+                        scaler.scale(loss).backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                scaler.unscale_(optimizer)
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            scaler.step(optimizer)
+                            scaler.update()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    else:
+                        # regular backpropagation when fp16 is not used
+                        loss.backward()
+                        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                            if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                                if train_config.enable_fsdp:
+                                    model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                                else:
+                                    torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                            optimizer.step()
+                            optimizer.zero_grad()
+                            pbar.update(1)
+                    if train_config.use_profiler or train_config.flop_counter:
+                        profile_context.step()
+                    if train_config.flop_counter and profile_context.is_done():
+                        TFlops = profile_context.get_flops_per_sec() / 1e12
+                    if wandb_run:
+                        if not train_config.enable_fsdp or rank==0:
+                            wandb_run.log({
+                                'train/epoch': epoch + 1,
+                                'train/step': epoch * len(train_dataloader) + step,
+                                'train/loss': loss.detach().float(),
+                            })
+
+                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                    if train_config.save_metrics:
+                        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+                pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
@@ -180,7 +218,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         # Update the learning rate as needed
         lr_scheduler.step()
-
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
@@ -266,7 +303,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
         results["metrics_filename"] = metrics_filename
-
+    if train_config.flop_counter:
+        results["model_tflops"]= TFlops
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft and rank==0:
         save_train_params(train_config, fsdp_config, rank)
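
Taken together, the changes to train() reduce to the pattern below. This is a condensed, illustrative sketch of the integration, not the full training loop: profile() and the config fields are the ones added in this commit, while model, optimizer and train_dataloader stand in for the real objects, and gradient accumulation, autocast and fp16 scaling are omitted.

# Condensed sketch of how the profiling context wraps the training loop.
with profile(train_config, local_rank) as profile_context:
    for step, batch in enumerate(train_dataloader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # advance the profiler schedule / flop-counter warmup once per step
        if train_config.use_profiler or train_config.flop_counter:
            profile_context.step()
        # once the flop counter finishes its measured step, read the throughput
        if train_config.flop_counter and profile_context.is_done():
            TFlops = profile_context.get_flops_per_sec() / 1e12
# when flop_counter is enabled, TFlops is later reported as results["model_tflops"]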