|
@@ -2,13 +2,14 @@
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
import random
|
|
import random
|
|
|
|
+from itertools import islice
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
import torch
|
|
import torch
|
|
|
|
|
|
|
|
|
|
class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
|
|
class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
|
|
- def __init__(self, data_source, batch_size, drop_last, randomize=True):
|
|
|
|
|
|
+ def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
|
|
if isinstance(next(iter(data_source)), dict):
|
|
if isinstance(next(iter(data_source)), dict):
|
|
first_key = next(iter(next(iter(data_source)).keys()))
|
|
first_key = next(iter(next(iter(data_source)).keys()))
|
|
self.lengths = [len(d[first_key]) for d in data_source]
|
|
self.lengths = [len(d[first_key]) for d in data_source]
|
|
@@ -16,7 +17,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
|
|
self.lengths = [len(d) for d in data_source]
|
|
self.lengths = [len(d) for d in data_source]
|
|
self.batch_size = batch_size
|
|
self.batch_size = batch_size
|
|
self.drop_last = drop_last
|
|
self.drop_last = drop_last
|
|
- self.randomize = randomize
|
|
|
|
|
|
+ self.shuffle = shuffle
|
|
|
|
|
|
def __iter__(self):
|
|
def __iter__(self):
|
|
ids = np.argsort(self.lengths)
|
|
ids = np.argsort(self.lengths)
|
|
@@ -25,7 +26,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
|
|
|
|
|
|
batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
|
|
batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
|
|
|
|
|
|
- if self.randomize:
|
|
|
|
|
|
+ if self.shuffle:
|
|
random.shuffle(batches)
|
|
random.shuffle(batches)
|
|
|
|
|
|
for b in batches:
|
|
for b in batches:
|
|
@@ -36,3 +37,21 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
|
|
return len(self.lengths) // self.batch_size
|
|
return len(self.lengths) // self.batch_size
|
|
else:
|
|
else:
|
|
return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
|
|
return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Split the batches of a ``LengthBasedBatchSampler`` across replicas.

    A ``LengthBasedBatchSampler`` is built over ``data_source`` with
    ``drop_last=True``. Each replica then takes every ``num_replicas``-th
    batch of that sampler, starting at index ``rank``. Trailing batches that
    cannot be distributed evenly are discarded, so every replica yields the
    same number of batches.

    Args:
        data_source: Dataset forwarded unchanged to ``LengthBasedBatchSampler``.
        batch_size: Number of samples per batch.
        num_replicas: Total number of replicas sharing the data; must be > 0.
        rank: Index of this replica; must satisfy ``0 <= rank < num_replicas``.
        shuffle: Forwarded to ``LengthBasedBatchSampler``.
        seed: Value used to seed Python's global ``random`` module.

    Raises:
        ValueError: If ``num_replicas`` is not positive or ``rank`` is outside
            ``[0, num_replicas)``.
    """

    def __init__(
        self,
        data_source,
        batch_size: int,
        num_replicas: int,
        rank: int,
        shuffle: bool = True,
        seed: int = 0,
    ) -> None:
        # Fail fast: a zero replica count would only surface later as a
        # ZeroDivisionError, and an out-of-range rank would silently yield a
        # slice that overlaps another replica's batches.
        if num_replicas <= 0:
            raise ValueError(
                f"num_replicas must be a positive integer, got {num_replicas}"
            )
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )

        # NOTE: this reseeds the *global* ``random`` module. The wrapped
        # sampler presumably shuffles with that module, so replicas created
        # with the same seed see the same batch order, which the strided
        # slicing in ``__iter__`` relies on to hand out disjoint batches.
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        """Return an iterator over the batches assigned to this replica."""
        # Round the batch count down to a multiple of num_replicas so that
        # every replica receives exactly the same number of batches.
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        """Return the number of batches each replica yields."""
        return len(self.batch_sampler) // self.num_replicas
|
|
|
|
+
|