12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- from typing import Any, Dict, List, Optional, Union
- import time
- import torch
- from torch.utils.flop_counter import FlopCounterMode
- class FlopMeasure(FlopCounterMode):
- """
- ``FlopMeasure`` is a customized context manager that counts the number of
- flops within its context. It is based on ``FlopCounterMode`` with additional start_counting() and stop_counting() function so that the flop counting
- will only start after the warmup stage.
- It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
- Example usage
- .. code-block:: python
- model = ...
- flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3)
- for batch in enumerate(dataloader):
- with flop_counter:
- model(batch)
- flop_counter.step()
- """
- def __init__(
- self,
- mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
- depth: int = 2,
- display: bool = True,
- custom_mapping: Dict[Any, Any] = None,
- rank=None,
- warmup_step: int = 3,
- ):
- super().__init__(mods, depth, display, custom_mapping)
- self.rank = rank
- self.warmup_step = warmup_step
- self.start_time = 0
- self.end_time = 0
- def step(self):
- # decrease the warmup step by 1 for every step, so that the flop counting will start when warmup_step =0. Stop decreasing when warm_up reaches -1.
- if self.warmup_step >= 0:
- self.warmup_step -= 1
- if self.warmup_step == 0 and self.start_time == 0:
- self.start_time = time.time()
- elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
- self.end_time = time.time()
- def __enter__(self):
- if self.warmup_step == 0:
- self.start_time = time.time()
- super().__enter__()
- return self
- def is_done(self):
- return self.warmup_step == -1
- def get_total_flops(self):
- return super().get_total_flops()
- def get_flops_per_sec(self):
- if self.start_time == 0 or self.end_time == 0:
- print("Warning: flop count did not finish correctly")
- return 0
- return super().get_total_flops()/ (self.end_time - self.start_time)
- def get_table(self, depth=2):
- return super().get_table(depth)
- def __exit__(self, *args):
- if self.get_total_flops() == 0:
- print(
- "Warning: did not record any flops this time. Skipping the flop report"
- )
- else:
- if self.display:
- if self.rank is None or self.rank == 0:
- print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
- print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
- print("The tflop_count table is below:")
- print(self.get_table(self.depth))
- # Disable the display feature so that we don't print the table again
- self.display = False
- super().__exit__(*args)
- def __torch_dispatch__(self, func, types, args=(), kwargs=None):
- # when warmup_step is 0, count the flops and return the original output
- if self.warmup_step == 0:
- return super().__torch_dispatch__(func, types, args, kwargs)
- # otherwise, just return the original output
- return func(*args, **kwargs)
|