
second draft of this feature, seems to be working now

Kai Wu 1 year ago
parent
commit
34e0bf4c6e
2 changed files with 9 additions and 14 deletions
  1. + 3 - 6
      src/llama_recipes/utils/tflop_counter.py
  2. + 6 - 8
      src/llama_recipes/utils/train_utils.py

+ 3 - 6
src/llama_recipes/utils/tflop_counter.py

@@ -451,22 +451,19 @@ class FlopCounterMode(TorchDispatchMode):
             self.stop_counting()
             if self.display:
                 if self.rank is None or self.rank == 0:
-                    print("exiting flop counter")
                     print("self.flop_counts", self.flop_counts["Global"].values())
                     print(self.get_table(self.depth))
         super().__exit__(*args)
     def start_counting(self):
         self.flop_counts.clear()
         self.ready = True
-        print("start_counting")
+    def is_ready(self):
+        return self.ready
     def stop_counting(self):
         self.ready = False
-        print("stop_counting")
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if not self.ready:
-            print("not ready yet, stop_counting")
-            return
-        print("in torch_dispatch now")
+            return func(*args, **kwargs)
         kwargs = kwargs if kwargs else {}
         out = func(*args, **kwargs)
         func_packet = func._overloadpacket
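
For context, the effect of this change is that __torch_dispatch__ now simply forwards the op until start_counting() has been called, and is_ready() exposes that state to callers. A minimal sketch of driving the counter from outside, assuming FlopCounterMode can be constructed with default arguments and enters with ready=False (neither __init__ nor __enter__ is shown in the hunk, so both are assumptions):

    import torch
    from llama_recipes.utils.tflop_counter import FlopCounterMode

    model = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16)

    flop_counter = FlopCounterMode()      # default constructor args assumed
    with flop_counter:
        model(x)                          # not recorded: ready is still False,
                                          # so dispatch just forwards the op
        flop_counter.start_counting()     # clears counts and sets ready = True
        model(x)                          # this forward pass is counted
        if flop_counter.is_ready():
            flop_counter.stop_counting()  # sets ready = False again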

+ 6 - 8
src/llama_recipes/utils/train_utils.py

@@ -38,7 +38,7 @@ def throughput_measure_context(cfg, local_rank=None):
     if use_flop_counter and use_profiler:
         raise ValueError("Cannot use both profiler and flop counter")
     if use_profiler:
-        print(f"profiling is activated and results will be saved in {cfg.profile_dir}")
+        print(f"profiling is activated and results will be saved in {cfg.profiler_dir}")
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
@@ -46,7 +46,7 @@ def throughput_measure_context(cfg, local_rank=None):
             ],
             schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
             on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                cfg.profile_dir
+                cfg.profiler_dir
             ),
             profile_memory=True,
             with_stack=False,
@@ -138,11 +138,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         break
                     if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
                         print("start flop counting at the step: ", total_train_steps)
-                        measure_context.start()
-                    if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint + 1:
-                        print("stop flop counting at the step: ", total_train_steps)
-                        TFlops = get_total_flops(flop_counter) / 1e12
-                        measure_context.stop()
+                        measure_context.start_counting()
                     for key in batch.keys():
                         if train_config.enable_fsdp:
                             if is_xpu_available():
@@ -188,7 +184,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             optimizer.step()
                             optimizer.zero_grad()
                             pbar.update(1)
-
+                    if train_config.flop_counter and measure_context.is_ready():
+                        TFlops = get_total_flops(measure_context) / 1e12
+                        measure_context.stop_counting()
                     if wandb_run:
                         if not train_config.enable_fsdp or rank==0:
                             wandb_run.log({
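
Putting the train_utils.py side together, the per-step control flow after this commit amounts roughly to the sketch below. The loop structure and the forward/backward/optimizer lines are simplified placeholders; train_config.flop_counter, flop_counter_startpoint, measure_context, is_ready, and get_total_flops are the names used in the diff:

    # Sketch only: counting is armed at flop_counter_startpoint and read back
    # (then disarmed) at the end of that same step, so exactly one step is measured.
    TFlops = None
    for total_train_steps, batch in enumerate(train_dataloader):
        if train_config.flop_counter and total_train_steps == train_config.flop_counter_startpoint:
            print("start flop counting at the step: ", total_train_steps)
            measure_context.start_counting()

        loss = model(**batch).loss          # simplified forward/backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if train_config.flop_counter and measure_context.is_ready():
            TFlops = get_total_flops(measure_context) / 1e12   # FLOPs for this one step
            measure_context.stop_counting()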