# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""


import torch
import random

from megatron import get_args
from megatron import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Build a dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler.
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
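
# Illustrative usage sketch (not part of the original file): `train_ds` and
# `args.consumed_train_samples` below are assumed stand-ins for whatever
# dataset object and consumed-sample counter the training script maintains.
#
#     train_dataloader = build_pretraining_data_loader(
#         train_ds, args.consumed_train_samples)
#     train_data_iterator = iter(train_dataloader)
#     batch = next(train_data_iterator)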


class MegatronPretrainingSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx
    def __iter__(self):
        batch = []
        # The last batch will be dropped if drop_last is not set to False.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and yield it if drop_last is not set.
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
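
# Worked example (illustrative, not from the original file): with
# total_samples=10, consumed_samples=0, micro_batch_size=2 and
# data_parallel_size=2, each global batch spans 4 consecutive indices and each
# rank slices out its own contiguous micro batch:
#
#     rank 0 yields [0, 1], then [4, 5]
#     rank 1 yields [2, 3], then [6, 7]
#
# Indices 8 and 9 form an incomplete global batch and are dropped because
# drop_last defaults to True.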


class MegatronPretrainingRandomSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch will be dropped if it is not complete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
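
# Worked example (illustrative, not from the original file): with
# total_samples=10, micro_batch_size=2, data_parallel_size=2 and
# consumed_samples=0, we get last_batch_size = 10 % 4 = 2, so
# active_total_samples = 8, epoch = 0 and bucket_size = (10 // 4) * 2 = 4.
# Rank 0 shuffles indices 0-3 and rank 1 shuffles indices 4-7; the permutation
# is seeded by the epoch, so it is identical across ranks and reproducible on
# restart, while the two indices of the incomplete last batch (8 and 9) are
# never drawn.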