@@ -2,8 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
-from dataclasses import dataclass
 from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch
 
 @dataclass
@@ -12,19 +13,23 @@ class Config:
 
 EXPECTED_SAMPLE_NUMBER ={
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
     },
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "train": 79,
-        "eval": 34,
+        "train": 3,
+        "eval": 30,
     },
     "fake_llama": {
-        "train": 50,
-        "eval": 21,
+        "train": 2,
+        "eval": 17,
     }
 }
 
+fake_samsum_dataset = 2048*[{'id': '420',
+    'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+    'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -34,7 +39,9 @@ EXPECTED_SAMPLE_NUMBER ={
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_packing(
+    datasets,
     step_lr,
     optimizer,
     get_model,
@@ -55,6 +62,8 @@ def test_packing(
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
 
     kwargs = {
         "model_name": llama_version,
@@ -106,7 +115,9 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_distributed_packing(
+    datasets,
     dist,
     is_initialized,
     fsdp,
@@ -137,6 +148,8 @@ def test_distributed_packing(
     cuda_is_available.return_value = False
     cuda_is_bf16_supported.return_value = False
 
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
+
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
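
For context on how the new mocks fit together, here is a minimal, self-contained sketch of the same pattern outside the test suite. It is illustrative only: the load_samsum helper and the hf_datasets alias are invented for this example and are not llama_recipes code. The idea mirrors the patch target 'llama_recipes.datasets.samsum_dataset.datasets': the datasets module is replaced where the loader looks it up, so load_dataset hands back an in-memory Dataset built with Dataset.from_list and no samsum download happens.

# Sketch of the mocking pattern in the diff above (assumed names, not repo code).
from unittest.mock import patch

import datasets as hf_datasets
from datasets import Dataset


def load_samsum(split: str):
    # Resolves hf_datasets at call time, which is what makes the patch effective.
    return hf_datasets.load_dataset("samsum", split=split)


fake_samsum = 4 * [{
    "id": "420",
    "dialogue": "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!",
    "summary": "Mario and Luigi talk.",
}]

# Patch the module reference used by the loader; this stands in for
# @patch('llama_recipes.datasets.samsum_dataset.datasets') in the real tests.
with patch(f"{__name__}.hf_datasets") as mocked_datasets:
    mocked_datasets.load_dataset.return_value = Dataset.from_list(fake_samsum)
    ds = load_samsum(split="train")
    assert len(ds) == 4                                          # data comes from the fake list
    assert set(ds.column_names) == {"id", "dialogue", "summary"}
    print(ds[0]["summary"])                                      # no network access needed

Because the mock is applied to the module attribute at its point of use rather than to datasets.load_dataset globally, a single return_value serves every load_dataset call, whichever split is requested, which is exactly what lets the tests above run against fake_samsum_dataset without touching the Hub.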