|
@@ -1,32 +1,36 @@
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
-import pytest
|
|
|
from dataclasses import dataclass
|
|
|
from functools import partial
|
|
|
from unittest.mock import patch
|
|
|
+
|
|
|
+import pytest
|
|
|
from datasets import load_dataset
|
|
|
|
|
|
+
|
|
|
@dataclass
class Config:
    """Lightweight config stub carrying a single ``model_type`` field."""

    # Defaults to "llama" when no value is supplied at construction time.
    model_type: str = "llama"
|
|
|
|
|
|
+
|
|
|
# Probe the samsum dataset once at import time; the resulting flag feeds the
# ``skipif`` marker on the test below so the test is skipped when the dataset
# cannot be loaded.
try:
    load_dataset("knkarthick/samsum")
except (ValueError, OSError):
    # ValueError was the only failure handled originally. OSError additionally
    # covers the builtin ConnectionError / FileNotFoundError families.
    # NOTE(review): presumably these are what `datasets` raises when the Hub
    # is unreachable or the dataset is missing -- confirm against the
    # installed `datasets` version.
    SAMSUM_UNAVAILABLE = True
else:
    SAMSUM_UNAVAILABLE = False
|
|
|
|
|
|
+
|
|
|
@pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable")
|
|
|
@pytest.mark.skip_missing_tokenizer
|
|
|
-@patch('llama_cookbook.finetuning.train')
|
|
|
-@patch('llama_cookbook.finetuning.AutoTokenizer')
|
|
|
+@patch("llama_cookbook.finetuning.train")
|
|
|
+@patch("llama_cookbook.finetuning.AutoTokenizer")
|
|
|
@patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
|
|
|
@patch("llama_cookbook.finetuning.AutoProcessor")
|
|
|
@patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
|
|
|
-@patch('llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained')
|
|
|
-@patch('llama_cookbook.finetuning.optim.AdamW')
|
|
|
-@patch('llama_cookbook.finetuning.StepLR')
|
|
|
+@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
|
|
|
+@patch("llama_cookbook.finetuning.optim.AdamW")
|
|
|
+@patch("llama_cookbook.finetuning.StepLR")
|
|
|
def test_samsum_dataset(
|
|
|
step_lr,
|
|
|
optimizer,
|
|
@@ -39,11 +43,13 @@ def test_samsum_dataset(
|
|
|
mocker,
|
|
|
setup_tokenizer,
|
|
|
llama_version,
|
|
|
- ):
|
|
|
+):
|
|
|
from llama_cookbook.finetuning import main
|
|
|
|
|
|
setup_tokenizer(tokenizer)
|
|
|
- get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
|
|
|
+ get_model.return_value.get_input_embeddings.return_value.weight.shape = [
|
|
|
+ 32000 if "Llama-2" in llama_version else 128256
|
|
|
+ ]
|
|
|
get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
|
|
|
get_config.return_value = Config()
|
|
|
|
|
@@ -55,7 +61,7 @@ def test_samsum_dataset(
|
|
|
"use_peft": False,
|
|
|
"dataset": "samsum_dataset",
|
|
|
"batching_strategy": "padding",
|
|
|
- }
|
|
|
+ }
|
|
|
|
|
|
main(**kwargs)
|
|
|
|