# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import random
from itertools import islice
from typing import Optional

import numpy as np
import torch


class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that groups samples of similar length into the same batch.

    Sample indices are sorted by length (stable ``mergesort``) and cut into
    consecutive chunks of ``batch_size``; the order of those chunks is then
    optionally shuffled. Each yielded batch is a numpy array of dataset indices.

    Args:
        data_source: Iterable of samples. If its first sample is a ``dict``,
            every sample's length is ``len(sample[k])`` where ``k`` is the
            first key of that first sample; otherwise ``len(sample)`` is used.
            An empty ``data_source`` yields no batches.
        batch_size: Number of indices per batch; must be positive.
        drop_last: If True, drop the trailing partial batch. Because indices
            are sorted by length before chunking, the dropped samples are the
            longest ones.
        shuffle: If True, shuffle the order of the batches on every iteration.
        rng: Optional ``random.Random`` used for shuffling. When None (the
            default) the module-level ``random`` functions, i.e. the global
            RNG state, are used.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        data_source,
        batch_size: int,
        drop_last: bool,
        shuffle: bool = True,
        rng: Optional[random.Random] = None,
    ) -> None:
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        # Peek at the first sample only once; None signals an empty data source.
        first = next(iter(data_source), None)
        if isinstance(first, dict):
            # Measure every sample by the value under the first sample's first key.
            first_key = next(iter(first))
            self.lengths = [len(d[first_key]) for d in data_source]
        else:
            self.lengths = [len(d) for d in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.rng = rng

    def __iter__(self):
        """Yield batches of indices (numpy arrays) grouped by sample length."""
        # Stable sort keeps equal-length samples in dataset order.
        ids = np.argsort(self.lengths, kind='mergesort')
        if self.drop_last:
            ids = ids[:len(ids) // self.batch_size * self.batch_size]
        batches = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)]
        if self.shuffle:
            shuffler = random if self.rng is None else self.rng
            shuffler.shuffle(batches)
        yield from batches

    def __len__(self):
        """Return the number of batches one iteration yields."""
        full, remainder = divmod(len(self.lengths), self.batch_size)
        if self.drop_last or remainder == 0:
            return full
        return full + 1


class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Length-grouped batch sampler that shards batches across replicas.

    Wraps a ``LengthBasedBatchSampler`` (always with ``drop_last=True``) and
    gives replica ``rank`` every ``num_replicas``-th batch starting at offset
    ``rank``. The total batch count is truncated to a multiple of
    ``num_replicas`` so every replica yields the same number of batches.

    Shuffling uses a private ``random.Random(seed)``. Replicas built with the
    same ``seed`` therefore draw the same batch permutation on each iteration,
    provided they iterate the same number of times, even if other code
    consumes the global ``random`` state differently on each replica.

    Args:
        data_source: Iterable of samples; see ``LengthBasedBatchSampler``.
        batch_size: Number of indices per batch; must be positive.
        num_replicas: Total number of replicas; must be positive.
        rank: This replica's index, in ``[0, num_replicas)``.
        shuffle: If True, shuffle the batch order on every iteration.
        seed: Seed for the shuffling RNG; must match across replicas.

    Raises:
        ValueError: If ``num_replicas`` is not positive, ``rank`` is out of
            range, or ``batch_size`` is not positive.
    """

    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
        if num_replicas <= 0:
            raise ValueError(f"num_replicas must be a positive integer, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        # Kept for backward compatibility: callers may rely on the global seed
        # being set. The shuffle itself no longer depends on global RNG state.
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source,
            batch_size=batch_size,
            drop_last=True,
            shuffle=shuffle,
            rng=random.Random(seed),
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        """Yield this replica's share of the (possibly shuffled) batches."""
        # Truncate so every replica receives exactly len(self) batches.
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        """Return the number of batches this replica yields per iteration."""
        return len(self.batch_sampler) // self.num_replicas