View source code

formatted and removed duplicated or unused functions get_total_flops() and byte2mb()

Kai Wu 1 year ago
Parent
Commit
41434dc825
2 changed files with 15 additions and 10 deletions
  1. +14 -2
      src/llama_recipes/utils/flop_utils.py
  2. +1 -8
      src/llama_recipes/utils/train_utils.py

+ 14 - 2
src/llama_recipes/utils/flop_utils.py

@@ -1,6 +1,8 @@
+from typing import Any, Dict, List, Optional, Union
+
 import torch
 from torch.utils.flop_counter import FlopCounterMode
-from typing import List, Any, Dict, Optional, Union
+
 
 class FlopMeasure(FlopCounterMode):
     """
@@ -22,6 +24,7 @@ class FlopMeasure(FlopCounterMode):
                 mod(batch)
                 flop_counter.stop_counting()
     """
+
     def __init__(
         self,
         mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
@@ -33,18 +36,24 @@ class FlopMeasure(FlopCounterMode):
         super().__init__(mods, depth, display, custom_mapping)
         self.ready = False
         self.rank = rank
+
     def __enter__(self):
         self.ready = False
         super().__enter__()
         return self
+
     def get_total_flops(self):
         return super().get_total_flops()
+
     def get_table(self, depth=2):
         return super().get_table(depth)
+
     def __exit__(self, *args):
         self.ready = False
         if self.get_total_flops() == 0:
-            print("Warning: did not record any flops this time. Skipping the flop report")
+            print(
+                "Warning: did not record any flops this time. Skipping the flop report"
+            )
         else:
             self.stop_counting()
             if self.display:
@@ -57,10 +66,13 @@ class FlopMeasure(FlopCounterMode):
 
     def start_counting(self):
         self.ready = True
+
     def is_ready(self):
         return self.ready
+
     def stop_counting(self):
         self.ready = False
+
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         # return the original output if not ready
         if not self.ready:
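
For context, a minimal usage sketch of the FlopMeasure API touched above (not part of this commit; the toy nn.Linear model, random batch, and rank=0 are illustrative assumptions):

import torch

from llama_recipes.utils.flop_utils import FlopMeasure

# Hedged sketch: FLOPs are recorded only between start_counting() and
# stop_counting(); get_total_flops() returns the accumulated total.
model = torch.nn.Linear(1024, 1024)
batch = torch.randn(8, 1024)

with FlopMeasure(rank=0) as flop_counter:
    flop_counter.start_counting()        # begin recording FLOPs
    model(batch)                         # this forward pass is measured
    if flop_counter.is_ready():
        tflops = flop_counter.get_total_flops() / 1e12
        print(f"forward pass: {tflops:.6f} TFLOPs")
    flop_counter.stop_counting()         # stop recording before the context exits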

+ 1 - 8
src/llama_recipes/utils/train_utils.py

@@ -25,7 +25,6 @@ from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
-#from llama_recipes.utils.tflop_counter import FlopCounterMode
 from llama_recipes.utils.flop_utils import FlopMeasure
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -68,12 +67,6 @@ def throughput_measure_context(cfg, local_rank=None):
         torch_profiler = contextlib.nullcontext()
         yield None
 
-def get_total_flops(model):
-    return (sum([v for _, v in model.flop_counts["Global"].items()]))
-
-# Converting Bytes to Megabytes
-def byte2mb(x):
-    return int(x / 2**20)
 
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
@@ -194,7 +187,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.use_profiler:
                         measure_context.step()
                     if train_config.flop_counter and measure_context.is_ready():
-                        TFlops = get_total_flops(measure_context) / 1e12
+                        TFlops = measure_context.get_total_flops() / 1e12
                         measure_context.stop_counting()
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
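
For reference, a hedged sketch (again not from this commit) of why dropping the module-level helper is behavior-preserving: the removed get_total_flops(model) summed flop_counts["Global"], which is the same total the inherited FlopCounterMode.get_total_flops() returns, so calling the method on measure_context is a drop-in replacement. The toy module and batch below are illustrative.

import torch

from llama_recipes.utils.flop_utils import FlopMeasure

# Illustrative equivalence check between the removed helper's computation
# and the inherited get_total_flops() method.
model = torch.nn.Linear(64, 64)

with FlopMeasure(rank=0) as counter:
    counter.start_counting()
    model(torch.randn(4, 64))
    via_method = counter.get_total_flops()
    via_global_dict = sum(v for _, v in counter.flop_counts["Global"].items())
    assert via_method == via_global_dict  # same total, hence the cleanup
    print(f"TFlops = {via_method / 1e12:.9f}")
    counter.stop_counting()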