# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""

import torch

from megatron import get_args
from megatron import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
    """Build dataloader given an input dataset."""

    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler.
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
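
# Example call site (a hedged sketch, not from this file): in Megatron's
# training setup, `train_ds` is a map-style dataset and the consumed-sample
# count comes from the checkpoint, so resumed runs skip already-seen samples:
#
#     train_dataloader = build_pretraining_data_loader(
#         train_ds, args.consumed_train_samples)
#     train_data_iterator = iter(train_dataloader)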


class MegatronPretrainingSampler:
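    """Sequential batch sampler: walks sample indices in order starting
    from `consumed_samples`, and yields this data-parallel rank's slice of
    each global batch of micro_batch_size * data_parallel_size indices."""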

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
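        # Each global batch holds micro_batch_size * data_parallel_size
        # consecutive indices, and each rank takes its own contiguous slice.
        # E.g. (hypothetical numbers) with micro_batch_size=2 and
        # data_parallel_size=4, rank 1 slices [2:4] out of the batch of 8.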
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # The last batch will be dropped unless drop_last is set to False.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see if drop_last is set.
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]


class MegatronPretrainingRandomSampler:
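    """Shuffling batch sampler: partitions the sample space into contiguous
    per-rank buckets, shuffles each bucket with an epoch-seeded generator,
    skips the indices already consumed in the current epoch, and yields
    micro-batches of micro_batch_size indices."""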

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each rank draws from its own
        # contiguous bucket of the sample space, and bucket_offset skips
        # the part of the bucket already consumed in this epoch.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

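        # Seed the shuffle with the epoch number so that every rank (and
        # every restart from a checkpoint) draws the same permutation for
        # a given epoch.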
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch, if not complete, will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
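

# A minimal sanity sketch (hypothetical numbers, not part of Megatron): run
# this module directly to see how MegatronPretrainingSampler shards one
# index stream across two data-parallel ranks. Running it standalone
# assumes the `megatron` package is importable for the imports above.
if __name__ == '__main__':
    for rank in range(2):
        sampler = MegatronPretrainingSampler(
            total_samples=10, consumed_samples=0, micro_batch_size=2,
            data_parallel_rank=rank, data_parallel_size=2)
        print('rank', rank, list(sampler))
    # Expected output (drop_last=True drops the trailing [8, 9]):
    #   rank 0 [[0, 1], [4, 5]]
    #   rank 1 [[2, 3], [6, 7]]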