
Fixed typo and handled unexpected exit

Kai Wu, 1 year ago
parent commit a35519ee90

+ 4 - 0
src/llama_recipes/configs/training.py

@@ -42,3 +42,7 @@ class train_config:
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and Xformers memory-efficient kernels
     use_wandb: bool = False # Enable wandb for experiment tracking
     save_metrics: bool = False # saves training metrics to a JSON file for later plotting
+    flop_counter: bool = False # Enable a FLOP counter to measure model throughput; cannot be used together with the PyTorch profiler.
+    flop_counter_startpoint: int = 3 # Step at which FLOP counting starts; the default of 3 gives a 3-step warmup before counting begins.
+    use_profiler: bool = False # Enable the PyTorch profiler; cannot be used together with the FLOP counter.
+    profiler_dir: str = "PATH/to/save/profiler/results" # used only when the profiler is enabled
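
The four new fields slot into the existing train_config dataclass, so they can be toggled like any other option. Below is a minimal sketch of exercising them; the field names come from the diff above, while everything else (including the assert) is illustrative.

# Sketch: toggling the new fields on a train_config instance.
# Field names come from the diff above; the rest is illustrative.
from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.flop_counter = True             # measure throughput with the FLOP counter
cfg.flop_counter_startpoint = 3     # start counting after a 3-step warmup
cfg.use_profiler = False            # mutually exclusive with flop_counter
assert not (cfg.flop_counter and cfg.use_profiler), \
    "flop_counter and use_profiler cannot be enabled together"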

+ 13 - 4
src/llama_recipes/utils/tflop_counter.py

@@ -445,19 +445,28 @@ class FlopCounterMode(TorchDispatchMode):
         return self
 
     def __exit__(self, *args):
-        self.stop_counting()
-        if self.display:
-            if self.rank is None or self.rank == 0:
-                print(self.get_table(self.depth))
+        if self.get_total_flops() == 0:
+            print("Warning: no flops were recorded this run; skipping the flop report")
+        else:
+            self.stop_counting()
+            if self.display:
+                if self.rank is None or self.rank == 0:
+                    print("exiting flop counter")
+                    print("self.flop_counts", self.flop_counts["Global"].values())
+                    print(self.get_table(self.depth))
         super().__exit__(*args)
     def start_counting(self):
         self.flop_counts.clear()
         self.ready = True
+        print("start_counting")
     def stop_counting(self):
         self.ready = False
+        print("stop_counting")
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if not self.ready:
+            print("not ready yet, skipping flop counting")
             return
+        print("in torch_dispatch now")
         kwargs = kwargs if kwargs else {}
         out = func(*args, **kwargs)
         func_packet = func._overloadpacket
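
Taken together with __enter__, the start/stop pair above forms the counter's whole protocol: operations are tallied only while self.ready is True, and __exit__ now skips the report when nothing was recorded. Here is a standalone sketch of that protocol, using only the constructor and methods visible in this hunk; the toy Linear model and tensor shapes are stand-ins.

# Sketch of the FlopCounterMode protocol shown above; the Linear model
# and tensor shapes are stand-ins, the method names come from this hunk.
import torch
from llama_recipes.utils.tflop_counter import FlopCounterMode

model = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)

with FlopCounterMode(rank=None) as flop_counter:
    flop_counter.start_counting()      # clears counts and sets ready=True
    model(x)                           # ops are tallied in __torch_dispatch__
    flop_counter.stop_counting()       # ready=False: later ops are ignored
    total_flops = flop_counter.get_total_flops()
print(f"measured {total_flops / 1e12:.6f} TFLOPs for one forward pass")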

+ 4 - 3
src/llama_recipes/utils/train_utils.py

@@ -8,6 +8,7 @@ from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 from datetime import datetime
+import contextlib
 
 
 import torch
@@ -55,7 +56,7 @@ def throughput_measure_context(cfg, local_rank=None):
     elif use_flop_counter:
         torch_profiler = contextlib.nullcontext()
         with FlopCounterMode(rank=local_rank) as flop_counter:
-            yeild flop_counter
+            yield flop_counter
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
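
Since throughput_measure_context is a generator (hence the yeild/yield fix above), it is presumably wrapped with contextlib.contextmanager, which the new contextlib import also suggests. A hedged sketch of how a caller would consume it; the call shape follows the signature in the hunk header, and local_rank=0 is a stand-in.

# Sketch: consuming the generator above as a context manager, assuming
# it is decorated with @contextlib.contextmanager. cfg sets the new flag.
from llama_recipes.configs.training import train_config
from llama_recipes.utils.train_utils import throughput_measure_context

cfg = train_config()
cfg.flop_counter = True
with throughput_measure_context(cfg, local_rank=0) as flop_counter:
    pass  # training runs here; flop_counter is None when neither tool is enabled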
@@ -135,10 +136,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         if not train_config.enable_fsdp or local_rank==0:
                             print("max training steps reached, stopping training, total_train_steps: ", total_train_steps-1)
                         break
-                    if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
+                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
                         print("start flop counting at the step: ", total_train_steps)
                         measure_context.start()
-                    if traing_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
+                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
                         print("stop flop counting at the step: ", total_train_steps)
                         TFlops = get_total_flops(flop_counter) / 1e12
                         measure_context.stop()
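
The two corrected checks bracket a window of exactly one training step: counting starts when total_train_steps reaches flop_counter_startpoint, and the total is read one step later, so TFlops is the cost of a single step in units of 1e12 FLOPs. A condensed, hypothetical sketch of that window follows; the names track the hunk above, and run_one_step stands in for the real loop body.

# Condensed sketch of the measurement window in train(); train_config,
# measure_context, flop_counter, and get_total_flops follow the hunk above,
# while run_one_step is a hypothetical stand-in for the loop body.
for total_train_steps, batch in enumerate(train_dataloader, start=1):
    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
        measure_context.start()                        # begin counting after warmup
    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
        TFlops = get_total_flops(flop_counter) / 1e12  # FLOPs of the single counted step
        measure_context.stop()                         # later steps stay uncounted
    run_one_step(model, batch)                         # hypothetical forward/backward/step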