@@ -3,6 +3,7 @@

 import pytest
 from dataclasses import dataclass
+from contextlib import nullcontext
 from unittest.mock import patch

 @dataclass
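The `contextlib.nullcontext` import added in this hunk powers a conditional-expectation pattern used further down: the same `with` block either tolerates no exception or demands a `ValueError`, depending on `model_type`. A minimal standalone sketch of the pattern (the `might_fail` helper is hypothetical, not from the recipe):

```python
from contextlib import nullcontext

import pytest


def might_fail(should_fail: bool) -> None:
    # Hypothetical stand-in for the code under test.
    if should_fail:
        raise ValueError("unsupported configuration")


@pytest.mark.parametrize("should_fail", [False, True])
def test_might_fail(should_fail):
    # nullcontext() is a no-op context manager, so one with-block covers
    # both the success case and the expected-failure case.
    expectation = pytest.raises(ValueError) if should_fail else nullcontext()
    with expectation:
        might_fail(should_fail)
```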
@@ -19,14 +20,16 @@ EXPECTED_SAMPLE_NUMBER ={
         "eval": 34,
     },
     "fake_llama": {
-        "train": 48,
-        "eval": 34,
+        "train": 50,
+        "eval": 21,
     }
 }

 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
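The two decorators added to this stack (`AutoProcessor` and `MllamaForConditionalGeneration.from_pretrained`) each feed an extra mock into the test. `unittest.mock.patch` decorators inject their mocks bottom-up, which is why the new parameters land mid-signature, right after `get_model`, rather than at the front. A standalone sketch of that ordering rule:

```python
from unittest.mock import patch


@patch("os.getcwd")        # outer decorator -> second mock argument
@patch("os.path.exists")   # inner decorator -> first mock argument
def check(mock_exists, mock_getcwd):
    # The decorator nearest the function supplies the first parameter,
    # mirroring why step_lr (from StepLR, bottom of the stack above)
    # leads the test signature.
    assert mock_exists is not mock_getcwd


check()
```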
@@ -34,17 +37,22 @@ def test_packing(
     step_lr,
     optimizer,
     get_model,
+    get_mmodel,
+    processor,
     get_config,
     tokenizer,
     train,
     setup_tokenizer,
+    setup_processor,
     llama_version,
     model_type,
     ):
     from llama_recipes.finetuning import main

     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)

     kwargs = {
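This hunk pins `weight.shape` several attributes deep on both mocked model factories so that downstream vocabulary-size checks still work. `MagicMock` allows that directly, because intermediate attributes and call results spring into existence on first access. A standalone sketch:

```python
from unittest.mock import MagicMock

get_model = MagicMock()
# Only the leaf value needs to be pinned; every intermediate hop
# (return_value, get_input_embeddings, weight) is auto-created.
get_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]

model = get_model("some-model-name")  # hypothetical name; returns the mock
assert model.get_input_embeddings().weight.shape == [128256]
```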
@@ -56,35 +64,38 @@ def test_packing(
         "batching_strategy": "packing",
     }

-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

-    assert train.call_count == 1
+    with c:
+        main(**kwargs)
+
+    if model_type == "llama":
+        assert train.call_count == 1

-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]

-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
-    # assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
-    # print(f"{len(eval_dataloader)=}")
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]

-    # batch = next(iter(train_dataloader))
+        batch = next(iter(train_dataloader))

-    # assert "labels" in batch.keys()
-    # assert "input_ids" in batch.keys()
-    # assert "attention_mask" in batch.keys()
+        assert "labels" in batch.keys()
+        assert "input_ids" in batch.keys()
+        assert "attention_mask" in batch.keys()

-    # # assert batch["labels"][0].size(0) == 4096
-    # # assert batch["input_ids"][0].size(0) == 4096
-    # # assert batch["attention_mask"][0].size(0) == 4096
-    # print(batch["labels"][0].size(0))
-    # print(batch["input_ids"][0].size(0))
-    # print(batch["attention_mask"][0].size(0))
-
+        assert batch["labels"][0].size(0) == 4096
+        assert batch["input_ids"][0].size(0) == 4096
+        assert batch["attention_mask"][0].size(0) == 4096


+@patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
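The reworked assertions pull the dataloaders back out of the mocked `train` through `call_args`, which records the positional and keyword arguments of the most recent call. A standalone sketch of that retrieval:

```python
from unittest.mock import MagicMock

train = MagicMock()
train("model", "train_loader", "eval_loader", lr=1e-4)

# call_args unpacks into the (args, kwargs) of the last invocation.
args, kwargs = train.call_args
assert args[1] == "train_loader"
assert args[2] == "eval_loader"
assert kwargs == {"lr": 1e-4}
```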
@@ -92,12 +103,34 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_distributed_packing(
+    dist,
+    is_initialized,
+    fsdp,
+    setup,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    cuda_is_available,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+    ):
     import os
     from llama_recipes.finetuning import main

     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
+    cuda_is_available.return_value = False

     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
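The distributed variant fakes a two-process launch purely through environment variables, the same ones `torchrun` would export, and forces `torch.cuda.is_available` to `False` so the test stays CPU-only. When such variables should not leak past a single test, `patch.dict` gives a scoped version of the same trick; a sketch:

```python
import os
from unittest.mock import patch

# patch.dict restores os.environ when the block exits.
with patch.dict(os.environ, {"LOCAL_RANK": "1", "RANK": "1", "WORLD_SIZE": "2"}):
    assert int(os.environ["WORLD_SIZE"]) == 2
```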
@@ -120,13 +153,17 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     dist.get_rank.return_value = rank
     dist.get_world_size.return_value = 2

-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)

-    assert train.call_count == 1
+    if model_type == "llama":
+        assert train.call_count == 1

-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]

-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
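The final expectations are halved because `dist.get_world_size` is mocked to 2 and the packed batches get sharded across ranks. A standalone sketch of that arithmetic using `torch.utils.data.DistributedSampler` (an assumed stand-in here; the recipe may wire up its own sampler):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# 100 dummy samples split across num_replicas=2 ranks: each rank's loader
# iterates over 100 // 2 batches, mirroring the //2 in the asserts above.
dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=1, sampler=sampler)

assert len(loader) == len(dataset) // 2
```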