# flop_utils.py
  1. from typing import Any, Dict, List, Optional, Union
  2. import time
  3. import torch
  4. from torch.utils.flop_counter import FlopCounterMode
  5. class FlopMeasure(FlopCounterMode):
  6. """
  7. ``FlopMeasure`` is a customized context manager that counts the number of
  8. flops within its context. It is based on ``FlopCounterMode`` with additional start_counting() and stop_counting() function so that the flop counting
  9. will only start after the warmup stage.
  10. It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
  11. Example usage
  12. .. code-block:: python
  13. model = ...
  14. flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3)
  15. for batch in enumerate(dataloader):
  16. with flop_counter:
  17. model(batch)
  18. flop_counter.step()
  19. """
  20. def __init__(
  21. self,
  22. mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
  23. depth: int = 2,
  24. display: bool = True,
  25. custom_mapping: Dict[Any, Any] = None,
  26. rank=None,
  27. warmup_step: int = 3,
  28. ):
  29. super().__init__(mods, depth, display, custom_mapping)
  30. self.rank = rank
  31. self.warmup_step = warmup_step
  32. self.start_time = 0
  33. self.end_time = 0
  34. def step(self):
  35. # decrease the warmup step by 1 for every step, so that the flop counting will start when warmup_step =0. Stop decreasing when warm_up reaches -1.
  36. if self.warmup_step >= 0:
  37. self.warmup_step -= 1
  38. if self.warmup_step == 0 and self.start_time == 0:
  39. self.start_time = time.time()
  40. elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0:
  41. self.end_time = time.time()
  42. def __enter__(self):
  43. if self.warmup_step == 0:
  44. self.start_time = time.time()
  45. super().__enter__()
  46. return self
  47. def is_done(self):
  48. return self.warmup_step == -1
  49. def get_total_flops(self):
  50. return super().get_total_flops()
  51. def get_flops_per_sec(self):
  52. if self.start_time == 0 or self.end_time == 0:
  53. print("Warning: flop count did not finish correctly")
  54. return 0
  55. return super().get_total_flops()/ (self.end_time - self.start_time)
  56. def get_table(self, depth=2):
  57. return super().get_table(depth)
  58. def __exit__(self, *args):
  59. if self.get_total_flops() == 0:
  60. print(
  61. "Warning: did not record any flops this time. Skipping the flop report"
  62. )
  63. else:
  64. if self.display:
  65. if self.rank is None or self.rank == 0:
  66. print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time))
  67. print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12))
  68. print("The tflop_count table is below:")
  69. print(self.get_table(self.depth))
  70. # Disable the display feature so that we don't print the table again
  71. self.display = False
  72. super().__exit__(*args)
  73. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  74. # when warmup_step is 0, count the flops and return the original output
  75. if self.warmup_step == 0:
  76. return super().__torch_dispatch__(func, types, args, kwargs)
  77. # otherwise, just return the original output
  78. return func(*args, **kwargs)