| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 | 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
- import random
 
- import pytest
 
- import torch
 
- from llama_recipes.data.sampler import LengthBasedBatchSampler
 
- from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
 
# Total number of synthetic samples produced by the ``dataset`` fixture.
SAMPLES = 33


@pytest.fixture
def dataset():
    """Return a deterministic, shuffled list of ``SAMPLES`` lists of ones.

    The first ``SAMPLES // 2`` generated samples have a length in [1, 9]
    ("short"); the remaining ones have a length in [10, 20] ("long"). The
    samples are shuffled before being returned so short and long entries
    are interleaved.
    """
    # Seed the module-level RNG so every test sees the same data.
    random.seed(42)

    short_count = SAMPLES // 2
    long_count = SAMPLES - short_count

    # Short samples are drawn first, then long ones, then the shuffle.
    samples = [[1] * random.randint(1, 9) for _ in range(short_count)]
    samples += [[1] * random.randint(10, 20) for _ in range(long_count)]

    return random.sample(samples, len(samples))
 
-     
 
-     
 
@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_array(dataset, batch_size, drop_last):
    """Check ``LengthBasedBatchSampler`` on a dataset of plain lists.

    Verifies the number of batches, the number of distinct indices yielded,
    and that no batch mixes short (length < 10) and long (length >= 10)
    samples.
    """
    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)

    full_batches, remainder = divmod(SAMPLES, batch_size)
    # A non-empty remainder adds exactly ONE trailing batch when it is kept,
    # regardless of how many samples that batch holds.
    expected_batches = full_batches if drop_last else full_batches + (1 if remainder else 0)
    expected_ids = full_batches * batch_size if drop_last else len(dataset)

    all_ids = [i for batch in sampler for i in batch]
    # Compare against a precomputed value: writing
    # ``assert a == b if cond else c`` parses as ``assert (a == b) if cond else c``
    # and would silently assert the truthy ``c`` when ``cond`` is false.
    assert len(set(all_ids)) == expected_ids

    assert len(sampler) == expected_batches

    is_long = [len(d) >= 10 for d in dataset]

    def is_homogeneous(flags):
        # ``flags`` must be a re-iterable sequence: ``all`` and ``any`` each
        # need to see every element.
        return all(flags) or not any(flags)

    # Materialise each batch's flags as a list; a generator would be partly
    # consumed by ``all`` and let mixed batches slip through ``any``.
    assert all(is_homogeneous([is_long[i] for i in batch]) for batch in sampler)
 
-     
 
-     
 
@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
def test_batch_sampler_dict(dataset, batch_size, drop_last):
    """Check ``LengthBasedBatchSampler`` on a dataset of dict samples.

    Each sample is wrapped as ``{"input_ids": ..., "attention_mask": ...}``.
    Verifies the number of batches and that no batch mixes short
    (length < 10) and long (length >= 10) samples.
    """
    dict_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]

    sampler = LengthBasedBatchSampler(dict_dataset, batch_size, drop_last)

    full_batches, remainder = divmod(SAMPLES, batch_size)
    # A non-empty remainder adds exactly ONE trailing batch when it is kept,
    # regardless of how many samples that batch holds.
    expected_batches = full_batches if drop_last else full_batches + (1 if remainder else 0)

    assert len(sampler) == expected_batches

    is_long = [len(d) >= 10 for d in dataset]

    def is_homogeneous(flags):
        # ``flags`` must be a re-iterable sequence: ``all`` and ``any`` each
        # need to see every element.
        return all(flags) or not any(flags)

    # Materialise each batch's flags as a list; a generator would be partly
    # consumed by ``all`` and let mixed batches slip through ``any``.
    assert all(is_homogeneous([is_long[i] for i in batch]) for batch in sampler)
 
-     
 
-     
 
@pytest.mark.parametrize("batch_size", [2, 8])
def test_dist_batch_sampling(dataset, batch_size):
    """Check that two ranks of ``DistributedLengthBasedBatchSampler`` split the data.

    The index sets yielded for rank 0 and rank 1 must be disjoint and
    non-empty in total, and their combined size must equal the number of
    samples that fit into full batches.
    """
    world_size = 2

    # Build both samplers before iterating either of them.
    samplers = [
        DistributedLengthBasedBatchSampler(
            dataset,
            batch_size=batch_size,
            rank=rank,
            num_replicas=world_size,
            shuffle=False,
        )
        for rank in range(world_size)
    ]

    ids_rank0, ids_rank1 = [{i for batch in s for i in batch} for s in samplers]
    total_ids = len(ids_rank0) + len(ids_rank1)

    assert ids_rank0.isdisjoint(ids_rank1)
    assert total_ids > 0
    # NOTE(review): this expected total presumably relies on the number of
    # full batches being divisible by the replica count — confirm against
    # the sampler before adding other batch sizes.
    assert total_ids == len(dataset) // batch_size * batch_size
 
 
  |