microbatches.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Megatron number of micro-batches calculators."""
  16. from abc import ABC
  17. from abc import abstractmethod
  18. def build_num_microbatches_calculator(args):
  19. # Constant num micro-batches.
  20. if args.rampup_batch_size is None:
  21. num_microbatches_calculator = ConstantNumMicroBatches(
  22. args.global_batch_size, args.micro_batch_size,
  23. args.data_parallel_size)
  24. if args.rank == 0:
  25. print('setting number of micro-batches to constant {}'.format(
  26. num_microbatches_calculator.get()), flush=True)
  27. else:
  28. assert len(args.rampup_batch_size) == 3, 'expected the following ' \
  29. 'format: --rampup-batch-size <start batch size> ' \
  30. '<batch size incerement> <ramp-up samples>'
  31. start_batch_size = int(args.rampup_batch_size[0])
  32. batch_size_increment = int(args.rampup_batch_size[1])
  33. ramup_samples = int(args.rampup_batch_size[2])
  34. if args.rank == 0:
  35. print('will use batch size rampup starting from global batch '
  36. 'size {} to global batch size {} with batch size increments '
  37. '{} over {} samples.'.format(start_batch_size,
  38. args.global_batch_size,
  39. batch_size_increment,
  40. ramup_samples), flush=True)
  41. num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
  42. start_batch_size, batch_size_increment, ramup_samples,
  43. args.global_batch_size, args.micro_batch_size,
  44. args.data_parallel_size)
  45. return num_microbatches_calculator
  46. class NumMicroBatchesCalculator(ABC):
  47. def __init__(self):
  48. self.num_micro_batches = None
  49. self.current_global_batch_size = None
  50. def get(self):
  51. return self.num_micro_batches
  52. def get_current_global_batch_size(self):
  53. return self.current_global_batch_size
  54. @abstractmethod
  55. def update(self, consumed_samples, consistency_check):
  56. pass
  57. class ConstantNumMicroBatches(NumMicroBatchesCalculator):
  58. def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
  59. micro_batch_times_data_parallel = micro_batch_size * \
  60. data_parallel_size
  61. assert global_batch_size % micro_batch_times_data_parallel == 0, \
  62. 'global batch size ({}) is not divisible by micro batch size ({})' \
  63. ' times data parallel size ({})'.format(global_batch_size,
  64. micro_batch_size,
  65. data_parallel_size)
  66. self.num_micro_batches = global_batch_size // \
  67. micro_batch_times_data_parallel
  68. assert self.num_micro_batches >= 1
  69. self.current_global_batch_size = global_batch_size
  70. def update(self, consumed_samples, consistency_check):
  71. pass
  72. class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
  73. def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
  74. global_batch_size, micro_batch_size, data_parallel_size):
  75. """Batch size ramp up.
  76. Over
  77. steps = (global-batch-size - start-batch-size) / batch_size_increment
  78. increment batch size from start-batch-size to global-batch-size using
  79. rampup-samples / steps
  80. samples.
  81. Arguments:
  82. start_batch_size: global batch size to start with
  83. batch_size_increment: global batch size increments
  84. ramup_samples: number of samples to use ramp up global
  85. batch size from `start_batch_size` to `global_batch_size`
  86. global_batch_size: global batch size post rampup
  87. micro_batch_size: micro batch size
  88. data_parallel_size: data parallel size.
  89. """
  90. self.micro_batch_size = micro_batch_size
  91. self.data_parallel_size = data_parallel_size
  92. self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
  93. self.data_parallel_size
  94. assert self.micro_batch_times_data_parallel_size > 0
  95. assert start_batch_size > 0
  96. self.start_batch_size = start_batch_size
  97. assert global_batch_size > 0
  98. self.global_batch_size = global_batch_size
  99. diff_batch_size = self.global_batch_size - self.start_batch_size
  100. assert diff_batch_size >= 0
  101. assert batch_size_increment > 0
  102. self.batch_size_increment = batch_size_increment
  103. assert diff_batch_size % batch_size_increment == 0, 'expected ' \
  104. 'global batch size interval ({}) to be divisible by global batch ' \
  105. 'size increment ({})'.format(diff_batch_size, batch_size_increment)
  106. num_increments = diff_batch_size // self.batch_size_increment
  107. self.ramup_samples = ramup_samples
  108. assert self.ramup_samples >= 0
  109. self.rampup_samples_per_increment = self.ramup_samples / num_increments
  110. # Initialize number of microbatches.
  111. self.update(0, False)
  112. def update(self, consumed_samples, consistency_check):
  113. if consumed_samples > self.ramup_samples:
  114. self.current_global_batch_size = self.global_batch_size
  115. else:
  116. steps = int(consumed_samples / self.rampup_samples_per_increment)
  117. self.current_global_batch_size = self.start_batch_size + \
  118. steps * self.batch_size_increment
  119. assert self.current_global_batch_size <= self.global_batch_size
  120. if consistency_check:
  121. assert self.current_global_batch_size % \
  122. self.micro_batch_times_data_parallel_size == 0, 'current global ' \
  123. 'batch size ({}) is not divisible by micro-batch-size ({}) times' \
  124. 'data parallel size ({})'.format(self.current_global_batch_size,
  125. self.micro_batch_size,
  126. self.data_parallel_size)
  127. self.num_micro_batches = self.current_global_batch_size // \
  128. self.micro_batch_times_data_parallel_size