Browse Source

Fix/unit test 3.2 (#726)

Sanyam Bhutani 6 months ago
parent
commit
b554b24b6e

+ 6 - 13
.github/workflows/pytest_cpu_gha_runner.yaml

@@ -1,16 +1,10 @@
 name: "[GHA][CPU] llama-recipes Pytest tests on CPU GitHub hosted runner."
 name: "[GHA][CPU] llama-recipes Pytest tests on CPU GitHub hosted runner."
 on:
 on:
   pull_request:
   pull_request:
-    branches:    
+    branches:
       - 'main'
       - 'main'
-    paths:
-      - 'src/llama-recipes/configs/*.py'
-      - 'src/llama-recipes/utils/*.py'
-      - 'src/llama-recipes/datasets/*.py'
-      - 'src/llama-recipes/data/*.py'
-      - 'src/llama-recipes/*.py'
 
 
-  # triggers workflow manually for debugging purposes.      
+  # triggers workflow manually for debugging purposes.
   workflow_dispatch:
   workflow_dispatch:
     inputs:
     inputs:
       runner:
       runner:
@@ -23,8 +17,8 @@ on:
           required: false
           required: false
           default: "true"
           default: "true"
 
 
-env: 
-  PYTORCH_WHEEL_URL: https://download.pytorch.org/whl/test/cu118  
+env:
+  PYTORCH_WHEEL_URL: https://download.pytorch.org/whl/test/cu118
 
 
 jobs:
 jobs:
   execute_workflow:
   execute_workflow:
@@ -63,7 +57,7 @@ jobs:
         id: install_llama_recipes_package
         id: install_llama_recipes_package
         run: |
         run: |
           echo "Installing 'llama-recipes' project (re: https://github.com/facebookresearch/llama-recipes?tab=readme-ov-file#install-with-optional-dependencies)"
           echo "Installing 'llama-recipes' project (re: https://github.com/facebookresearch/llama-recipes?tab=readme-ov-file#install-with-optional-dependencies)"
-          pip install --extra-index-url ${PYTORCH_WHEEL_URL} -e '.[tests]' 
+          pip install --extra-index-url ${PYTORCH_WHEEL_URL} -e '.[tests]'
 
 
 
 
       - name: "Running PyTest tests on GHA CPU Runner"
       - name: "Running PyTest tests on GHA CPU Runner"
@@ -71,11 +65,10 @@ jobs:
         run: |
         run: |
           echo "Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE}"
           echo "Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE}"
           cd $GITHUB_WORKSPACE && python3 -m pytest --junitxml="$GITHUB_WORKSPACE/result.xml"
           cd $GITHUB_WORKSPACE && python3 -m pytest --junitxml="$GITHUB_WORKSPACE/result.xml"
-  
+
       - name: Publish Test Summary
       - name: Publish Test Summary
         id: test_summary
         id: test_summary
         uses: test-summary/action@v2
         uses: test-summary/action@v2
         with:
         with:
           paths: "**/*.xml"
           paths: "**/*.xml"
         if: always()
         if: always()
-          

+ 0 - 1
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     train_split: str = "train"
     test_split: str = "validation"
     test_split: str = "validation"
-    trust_remote_code: bool = False
 
 
 
 
 @dataclass
 @dataclass

+ 14 - 3
src/llama_recipes/datasets/samsum_dataset.py

@@ -6,11 +6,22 @@
 import copy
 import copy
 import datasets
 import datasets
 
 
+from unittest.mock import patch
+
+@patch('builtins.input', return_value="N")
+def load_samsum(split, _):
+    try:
+        ds = datasets.load_dataset("Samsung/samsum", split=split)
+    except ValueError as e:
+        if "trust_remote_code" in str(e):
+          raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
+        else:
+          raise e
+    return ds
+
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
-    if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
-        raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
-    dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
+    dataset = load_samsum(split)
 
 
     prompt = (
     prompt = (
         f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
         f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

+ 1 - 1
src/llama_recipes/finetuning.py

@@ -289,7 +289,7 @@ def main(**kwargs):
         )
         )
         print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
         if len(eval_dataloader) == 0:
-            raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
+            raise ValueError(f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})")
         else:
         else:
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
 

+ 28 - 11
src/tests/conftest.py

@@ -3,19 +3,27 @@
 
 
 import pytest
 import pytest
 
 
-from transformers import AutoTokenizer
+from utils import maybe_tokenizer
 
 
-ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
-LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
+ACCESS_ERROR_MSG = "Could not access tokenizer. Did you log into huggingface hub and provided the correct token?"
+
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct", "fake_llama"]
+
+LLAMA_TOKENIZERS = {k: maybe_tokenizer(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture(params=LLAMA_VERSIONS)
 @pytest.fixture(params=LLAMA_VERSIONS)
 def llama_version(request):
 def llama_version(request):
     return request.param
     return request.param
 
 
 
 
+@pytest.fixture(params=["mllama", "llama"])
+def model_type(request):
+    return request.param
+
+
 @pytest.fixture(scope="module")
 @pytest.fixture(scope="module")
 def llama_tokenizer(request):
 def llama_tokenizer(request):
-    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
+    return LLAMA_TOKENIZERS
 
 
 
 
 @pytest.fixture
 @pytest.fixture
@@ -26,6 +34,13 @@ def setup_tokenizer(llama_tokenizer, llama_version):
 
 
     return _helper
     return _helper
 
 
+@pytest.fixture
+def setup_processor(llama_tokenizer, llama_version):
+    def _helper(processor_mock):
+        processor_mock.from_pretrained.return_value.tokenizer = llama_tokenizer[llama_version]
+
+    return _helper
+
 
 
 def pytest_addoption(parser):
 def pytest_addoption(parser):
     parser.addoption(
     parser.addoption(
@@ -38,16 +53,18 @@ def pytest_configure(config):
 
 
 
 
 def pytest_collection_modifyitems(config, items):
 def pytest_collection_modifyitems(config, items):
+    #skip tests marked with skip_missing_tokenizer if tokenizer is unavailable unless --unskip-missing-tokenizer is passed
     if config.getoption("--unskip-missing-tokenizer"):
     if config.getoption("--unskip-missing-tokenizer"):
         return
         return
 
 
-    try:
-        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-        tokenizer_available = True
-    except OSError:
-        tokenizer_available = False
-
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
     for item in items:
-        if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
+        # get the tokenizer for the test
+        version = [v for v in LLAMA_VERSIONS for i in item.keywords if v in i]
+        if len(version) == 0:
+            # no tokenizer used in this test
+            continue
+        version = version.pop()
+        assert version in LLAMA_TOKENIZERS
+        if "skip_missing_tokenizer" in item.keywords and LLAMA_TOKENIZERS[version] is None:
             item.add_marker(skip_missing_tokenizer)
             item.add_marker(skip_missing_tokenizer)

+ 9 - 3
src/tests/datasets/test_custom_dataset.py

@@ -2,6 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
 import pytest
 import pytest
+from contextlib import nullcontext
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 from transformers import LlamaTokenizer
 from transformers import LlamaTokenizer
@@ -96,15 +97,17 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker, llama_version):
+def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, get_config, train, mocker, llama_version):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
     tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_config.return_value.model_type = "llama"
 
 
     kwargs = {
     kwargs = {
         "dataset": "custom_dataset",
         "dataset": "custom_dataset",
@@ -131,13 +134,16 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
         {"role":"assistant", "content":"Romans"},
         {"role":"assistant", "content":"Romans"},
     ]
     ]
 
 
-    result = tokenize_dialog(dialog, tokenizer)
+    c = pytest.raises(AttributeError) if llama_version == "fake_llama" else nullcontext()
+
+    with c:
+        result = tokenize_dialog(dialog, tokenizer)
     
     
     if "Llama-2" in llama_version:
     if "Llama-2" in llama_version:
         assert result["labels"][:12] == [-100] * 12
         assert result["labels"][:12] == [-100] * 12
         assert result["labels"][17:28] == [-100] * 11
         assert result["labels"][17:28] == [-100] * 11
         assert result["labels"].count(-100) == 11 + 12
         assert result["labels"].count(-100) == 11 + 12
-    else:
+    elif "Llama-3" in llama_version:
         assert result["labels"][:38] == [-100] * 38
         assert result["labels"][:38] == [-100] * 38
         assert result["labels"][43:54] == [-100] * 11
         assert result["labels"][43:54] == [-100] * 11
         assert result["labels"].count(-100) == 38 + 11
         assert result["labels"].count(-100) == 38 + 11

+ 7 - 15
src/tests/datasets/test_grammar_datasets.py

@@ -1,32 +1,27 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+from pathlib import Path
 import pytest
 import pytest
 from unittest.mock import patch
 from unittest.mock import patch
 
 
-
-EXPECTED_RESULTS = {
-    "meta-llama/Llama-2-7b-hf":{
-        "label": 1152,
-        "pos": 31,
-    },
-    "meta-llama/Meta-Llama-3.1-8B":{
-        "label": 40,
-        "pos": 26,
-    },
-}
+DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
 
 
 @pytest.mark.skip_missing_tokenizer
 @pytest.mark.skip_missing_tokenizer
+@pytest.mark.skipif(not Path(DATA_DIR / "grammar_validation.csv").exists(), reason="grammar_validation.csv not found")
+@pytest.mark.skipif(not Path(DATA_DIR / "gtrain_10k.csv").exists(), reason="gtrain_10k.csv not found")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
     setup_tokenizer(tokenizer)
     setup_tokenizer(tokenizer)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_config.return_value.model_type = "llama"
 
 
     BATCH_SIZE = 8
     BATCH_SIZE = 8
     kwargs = {
     kwargs = {
@@ -58,9 +53,6 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_
     assert "input_ids" in batch.keys()
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
     assert "attention_mask" in batch.keys()
 
 
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
-
     token = args[3]
     token = args[3]
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id

+ 30 - 14
src/tests/datasets/test_samsum_datasets.py

@@ -2,31 +2,50 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
 import pytest
 import pytest
+from dataclasses import dataclass
 from functools import partial
 from functools import partial
 from unittest.mock import patch
 from unittest.mock import patch
+from datasets import load_dataset
 
 
-EXPECTED_RESULTS = {
-    "meta-llama/Llama-2-7b-hf":{
-        "label": 8432,
-        "pos": 242,
-    },
-    "meta-llama/Meta-Llama-3.1-8B":{
-        "label": 2250,
-        "pos": 211,
-    },
-}
+@dataclass
+class Config:
+    model_type: str = "llama"
 
 
+try:
+    load_dataset("Samsung/samsum")
+    SAMSUM_UNAVAILABLE = False
+except ValueError:
+    SAMSUM_UNAVAILABLE = True
+
+@pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable")
 @pytest.mark.skip_missing_tokenizer
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
+def test_samsum_dataset(
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    mocker,
+    setup_tokenizer,
+    llama_version,
+    ):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
     setup_tokenizer(tokenizer)
     setup_tokenizer(tokenizer)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config()
 
 
     BATCH_SIZE = 8
     BATCH_SIZE = 8
     kwargs = {
     kwargs = {
@@ -59,9 +78,6 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
     assert "attention_mask" in batch.keys()
 
 
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
-
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id
     assert batch["input_ids"][0][-1] == token.eos_token_id
     assert batch["input_ids"][0][-1] == token.eos_token_id

+ 107 - 28
src/tests/test_batching.py

@@ -2,30 +2,68 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
 import pytest
 import pytest
+from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch
 from unittest.mock import patch
 
 
+@dataclass
+class Config:
+    model_type: str = "llama"
+
 EXPECTED_SAMPLE_NUMBER ={
 EXPECTED_SAMPLE_NUMBER ={
     "meta-llama/Llama-2-7b-hf": {
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
+    },
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
+        "train": 3,
+        "eval": 30,
     },
     },
-    "meta-llama/Meta-Llama-3.1-8B": {
-        "train": 79,
-        "eval": 34,
+    "fake_llama": {
+        "train": 2,
+        "eval": 17,
     }
     }
 }
 }
 
 
+fake_samsum_dataset = 2048*[{'id': '420',
+ 'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+ 'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
+def test_packing(
+    datasets,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+    ):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
     setup_tokenizer(tokenizer)
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
     
     
     kwargs = {
     kwargs = {
         "model_name": llama_version,
         "model_name": llama_version,
@@ -36,31 +74,40 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenize
         "batching_strategy": "packing",
         "batching_strategy": "packing",
         }
         }
 
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else  pytest.raises(ValueError)
 
 
-    assert train.call_count == 1
+    with c:
+        main(**kwargs)
+    
+    if model_type == "llama":
+        assert train.call_count == 1
 
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
 
-    batch = next(iter(train_dataloader))
+        batch = next(iter(train_dataloader))
 
 
-    assert "labels" in batch.keys()
-    assert "input_ids" in batch.keys()
-    assert "attention_mask" in batch.keys()
+        assert "labels" in batch.keys()
+        assert "input_ids" in batch.keys()
+        assert "attention_mask" in batch.keys()
 
 
-    assert batch["labels"][0].size(0) == 4096
-    assert batch["input_ids"][0].size(0) == 4096
-    assert batch["attention_mask"][0].size(0) == 4096
+        assert batch["labels"][0].size(0) == 4096
+        assert batch["input_ids"][0].size(0) == 4096
+        assert batch["attention_mask"][0].size(0) == 4096
 
 
 
 
 @pytest.mark.skip_missing_tokenizer
 @pytest.mark.skip_missing_tokenizer
+@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoTokenizer')
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
@@ -68,12 +115,40 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenize
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
+def test_distributed_packing(
+    datasets,
+    dist,
+    is_initialized,
+    fsdp,
+    setup,
+    step_lr,
+    optimizer,
+    get_model,
+    get_mmodel,
+    processor,
+    get_config,
+    tokenizer,
+    train,
+    cuda_is_available,
+    cuda_is_bf16_supported,
+    setup_tokenizer,
+    setup_processor,
+    llama_version,
+    model_type,
+    ):
     import os
     import os
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
     setup_tokenizer(tokenizer)
     setup_tokenizer(tokenizer)
+    setup_processor(processor)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
+    cuda_is_available.return_value = False
+    cuda_is_bf16_supported.return_value = False
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
 
 
     rank = 1
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['LOCAL_RANK'] = f'{rank}'
@@ -96,13 +171,17 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     dist.get_rank.return_value = rank
     dist.get_rank.return_value = rank
     dist.get_world_size.return_value = 2
     dist.get_world_size.return_value = 2
 
 
-    main(**kwargs)
+    c = nullcontext() if model_type == "llama" else  pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
 
-    assert train.call_count == 1
+    if model_type == "llama":
+        assert train.call_count == 1
 
 
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
+        args, kwargs = train.call_args
+        train_dataloader = args[1]
+        eval_dataloader = args[2]
 
 
-    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
-    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
+        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

+ 19 - 73
src/tests/test_chat_completion.py

@@ -1,6 +1,6 @@
 import sys
 import sys
 from pathlib import Path
 from pathlib import Path
-from typing import List, Literal, TypedDict
+from typing import List, TypedDict
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 import pytest
 import pytest
@@ -8,46 +8,37 @@ import torch
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 
 
 ROOT_DIR = Path(__file__).parents[2]
 ROOT_DIR = Path(__file__).parents[2]
-CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/quickstart/inference/local_inference/chat_completion/"
 
 
 sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
 sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
 
 
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
+default_system_prompt = [{"role": "system", "content": "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"}]
 
 
 def _encode_header(message, tokenizer):
 def _encode_header(message, tokenizer):
     tokens = []
     tokens = []
-    tokens.extend(tokenizer.encode("<|start_header_id|>"))
-    tokens.extend(tokenizer.encode(message["role"]))
-    tokens.extend(tokenizer.encode("<|end_header_id|>"))
-    tokens.extend(tokenizer.encode("\n\n"))
+    tokens.extend(tokenizer.encode("<|start_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|end_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
     return tokens
     return tokens
 
 
 
 
 def _encode_message(message, tokenizer):
 def _encode_message(message, tokenizer):
     tokens = _encode_header(message, tokenizer)
     tokens = _encode_header(message, tokenizer)
-    tokens.extend(tokenizer.encode(message["content"].strip()))
-    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    tokens.extend(tokenizer.encode(message["content"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|eot_id|>", add_special_tokens=False))
     return tokens
     return tokens
 
 
 
 
 def _format_dialog(dialog, tokenizer):
 def _format_dialog(dialog, tokenizer):
     tokens = []
     tokens = []
-    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    tokens.extend(tokenizer.encode("<|begin_of_text|>", add_special_tokens=False))
+    if dialog[0]["role"] == "system":
+        dialog[0]["content"] = default_system_prompt[0]["content"] + dialog[0]["content"]
+    else:
+        dialog = default_system_prompt + dialog
     for msg in dialog:
     for msg in dialog:
         tokens.extend(_encode_message(msg, tokenizer))
         tokens.extend(_encode_message(msg, tokenizer))
-    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
     return tokens
     return tokens
 
 
 
 
@@ -55,59 +46,19 @@ def _format_tokens_llama3(dialogs, tokenizer):
     return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
     return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
 
 
 
 
-def _format_tokens_llama2(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-                {
-                    "role": dialog[1]["role"],
-                    "content": B_SYS
-                    + dialog[0]["content"]
-                    + E_SYS
-                    + dialog[1]["content"],
-                }
-            ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                )
-                + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-
-
 @pytest.mark.skip_missing_tokenizer
 @pytest.mark.skip_missing_tokenizer
 @patch("chat_completion.AutoTokenizer")
 @patch("chat_completion.AutoTokenizer")
 @patch("chat_completion.load_model")
 @patch("chat_completion.load_model")
 def test_chat_completion(
 def test_chat_completion(
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
 ):
 ):
+    if "Llama-2" in llama_version or llama_version == "fake_llama":
+        pytest.skip(f"skipping test for {llama_version}")
+
     from chat_completion import main
     from chat_completion import main
 
 
     setup_tokenizer(tokenizer)
     setup_tokenizer(tokenizer)
-    load_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    load_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]
 
 
     kwargs = {
     kwargs = {
         "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
         "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
@@ -116,13 +67,8 @@ def test_chat_completion(
     main(llama_version, **kwargs)
     main(llama_version, **kwargs)
 
 
     dialogs = read_dialogs_from_file(kwargs["prompt_file"])
     dialogs = read_dialogs_from_file(kwargs["prompt_file"])
-    format_tokens = (
-        _format_tokens_llama2
-        if llama_version == "meta-llama/Llama-2-7b-hf"
-        else _format_tokens_llama3
-    )
 
 
-    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+    REF_RESULT = _format_tokens_llama3(dialogs, llama_tokenizer[llama_version])
 
 
     assert all(
     assert all(
         (
         (

+ 110 - 102
src/tests/test_finetuning.py

@@ -2,6 +2,8 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
 import os
 import os
+from contextlib import nullcontext
+from dataclasses import dataclass
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 import pytest
 import pytest
@@ -16,8 +18,12 @@ from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.sampler import BatchSampler
 from torch.utils.data.sampler import BatchSampler
 
 
 
 
+@dataclass
+class Config:
+    model_type: str = "llama"
+
 def get_fake_dataset():
 def get_fake_dataset():
-    return [
+    return 8192*[
         {
         {
             "input_ids": [1],
             "input_ids": [1],
             "attention_mask": [1],
             "attention_mask": [1],
@@ -28,28 +34,49 @@ def get_fake_dataset():
 
 
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch("llama_recipes.finetuning.train")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.generate_peft_config")
+@patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 @patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_no_validation(
+@pytest.mark.parametrize("run_validation", [True, False])
+@pytest.mark.parametrize("use_peft", [True, False])
+def test_finetuning(
     step_lr,
     step_lr,
     optimizer,
     optimizer,
+    get_peft_model,
+    gen_peft_config,
     get_dataset,
     get_dataset,
     tokenizer,
     tokenizer,
+    get_config,
     get_model,
     get_model,
+    get_processor,
+    get_mmodel,
     train,
     train,
     cuda,
     cuda,
     cuda_is_available,
     cuda_is_available,
+    run_validation,
+    use_peft,
+    model_type,
 ):
 ):
-    kwargs = {"run_validation": False}
+    kwargs = {
+        "run_validation": run_validation,
+        "use_peft": use_peft,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+        }
 
 
     get_dataset.return_value = get_fake_dataset()
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
     cuda.return_value = cuda_is_available
 
 
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
 
     main(**kwargs)
     main(**kwargs)
 
 
@@ -60,115 +87,59 @@ def test_finetuning_no_validation(
     eval_dataloader = args[2]
     eval_dataloader = args[2]
 
 
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(train_dataloader, DataLoader)
-    assert eval_dataloader is None
-
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if run_validation:
+        assert isinstance(eval_dataloader, DataLoader)
     else:
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_with_validation(
-    step_lr,
-    optimizer,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"run_validation": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
-
-    assert train.call_count == 1
-
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
-    assert isinstance(train_dataloader, DataLoader)
-    assert isinstance(eval_dataloader, DataLoader)
+        assert eval_dataloader is None
 
 
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if use_peft:
+        assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        model = get_peft_model
+    elif model_type == "llama":
+        model = get_model
     else:
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.generate_peft_config")
-@patch("llama_recipes.finetuning.get_peft_model")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_peft_lora(
-    step_lr,
-    optimizer,
-    get_peft_model,
-    gen_peft_config,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"use_peft": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
+        model = get_mmodel
 
 
     if cuda_is_available:
     if cuda_is_available:
-        assert get_peft_model.return_value.to.call_count == 1
-        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+        assert model.return_value.to.call_count == 1
+        assert model.return_value.to.call_args.args[0] == "cuda"
     else:
     else:
-        assert get_peft_model.return_value.to.call_count == 0
-
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        assert model.return_value.to.call_count == 0
 
 
 
 
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.setup")
 @patch("llama_recipes.finetuning.setup")
 @patch("llama_recipes.finetuning.train")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 def test_finetuning_peft_llama_adapter(
 def test_finetuning_peft_llama_adapter(
-    get_dataset, tokenizer, get_model, train, setup, get_peft_model
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    setup,
+    get_peft_model,
+    model_type,
 ):
 ):
     kwargs = {
     kwargs = {
         "use_peft": True,
         "use_peft": True,
         "peft_method": "llama_adapter",
         "peft_method": "llama_adapter",
         "enable_fsdp": True,
         "enable_fsdp": True,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
     }
     }
 
 
     get_dataset.return_value = get_fake_dataset()
     get_dataset.return_value = get_fake_dataset()
 
 
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
 
     os.environ["RANK"] = "0"
     os.environ["RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
@@ -195,20 +166,38 @@ def test_finetuning_peft_llama_adapter(
 
 
 
 
 @patch("llama_recipes.finetuning.train")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.StepLR")
 @patch("llama_recipes.finetuning.StepLR")
 def test_finetuning_weight_decay(
 def test_finetuning_weight_decay(
-    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+    step_lr,
+    get_peft_model,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
 ):
-    kwargs = {"weight_decay": 0.01}
+    kwargs = {
+        "weight_decay": 0.01,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+        }
 
 
     get_dataset.return_value = get_fake_dataset()
     get_dataset.return_value = get_fake_dataset()
 
 
-    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.parameters.return_value = [torch.ones(1, 1)]
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    
+    get_config.return_value = Config(model_type=model_type)
 
 
     main(**kwargs)
     main(**kwargs)
 
 
@@ -217,35 +206,54 @@ def test_finetuning_weight_decay(
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     optimizer = args[4]
     optimizer = args[4]
 
 
-    print(optimizer.state_dict())
-
     assert isinstance(optimizer, AdamW)
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
 
 
 
 
 @patch("llama_recipes.finetuning.train")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 @patch("llama_recipes.finetuning.StepLR")
 def test_batching_strategy(
 def test_batching_strategy(
-    step_lr, optimizer, get_dataset, tokenizer, get_model, train
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
 ):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        }
 
 
     get_dataset.return_value = get_fake_dataset()
     get_dataset.return_value = get_fake_dataset()
 
 
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
 
 
-    main(**kwargs)
+    get_config.return_value = Config(model_type=model_type)
 
 
-    assert train.call_count == 1
+    c = nullcontext() if model_type == "llama" else  pytest.raises(ValueError)
+    
+    with c:
+        main(**kwargs)
 
 
-    args, kwargs = train.call_args
-    train_dataloader, eval_dataloader = args[1:3]
-    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
-    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+    assert train.call_count == (1 if model_type == "llama" else 0)
+
+    if model_type == "llama":
+        args, kwargs = train.call_args
+        train_dataloader, eval_dataloader = args[1:3]
+        assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+        assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
 
 
     kwargs["batching_strategy"] = "padding"
     kwargs["batching_strategy"] = "padding"
     train.reset_mock()
     train.reset_mock()

+ 12 - 1
src/tests/test_train_utils.py

@@ -27,10 +27,16 @@ def temp_output_dir():
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
-def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
+def test_gradient_accumulation(
+    autocast,
+    scaler,
+    nullcontext,
+    mem_trace,
+    mocker):
 
 
     model = mocker.MagicMock(name="model")
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -47,6 +53,9 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.max_train_step = 0
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.max_eval_step = 0
     train_config.save_metrics = False
     train_config.save_metrics = False
+    train_config.flop_counter_start = 0
+    train_config.use_profiler = False
+    train_config.flop_counter = True
 
 
     train(
     train(
         model,
         model,
@@ -86,6 +95,7 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
 def test_save_to_json(temp_output_dir, mocker):
 def test_save_to_json(temp_output_dir, mocker):
     model = mocker.MagicMock(name="model")
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -103,6 +113,7 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.max_train_step = 0
     train_config.max_train_step = 0
     train_config.max_eval_step = 0
     train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
     train_config.output_dir = temp_output_dir
+    train_config.flop_counter_start = 0
     train_config.use_profiler = False
     train_config.use_profiler = False
 
 
     results = train(
     results = train(

+ 50 - 0
src/tests/utils.py

@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from transformers import AutoTokenizer
+
+
+class FakeTokenizer(object):
+    def __init__(self):
+        self.pad_token_id = 0
+        self.bos_token_id = 42
+        self.eos_token_id = 43
+        self.sep_token_id = 3
+        self.vocab_size = 128256
+
+        self.pad_token = "<|pad_id|>"
+        self.bos_token = "<|bos_id|>"
+        self.eos_token = "<|eos_id|>"
+        self.sep_token = "<|sep_id|>"
+        self.tokenizer = self
+        self.padding_side = "left"
+
+    def __call__(self, *args, **kwargs):
+        ids = self.encode(*args, **kwargs)
+        return {"input_ids": ids}
+
+    def encode(self, text, *args, **kwargs):
+        return [self.bos_token_id] + [len(c) for c in text.split(" ")] + [self.eos_token_id]
+    
+    def __len__(self):
+        return 128256
+    
+    def pad(self, *args, **kwargs):
+        args = args[0]
+        max_len = max([len(a["input_ids"]) for a in args])
+        for a in args:
+            for k in a.keys():
+                a[k] = a[k] + ([self.pad_token_id if k == "input_ids" else 0] * (max_len - len(a)))
+        out = {}
+        for k in args[0].keys():
+            out[k] = [a[k] for a in args]
+        return out
+
+
+def maybe_tokenizer(name):
+    if name == "fake_llama":
+        return FakeTokenizer()
+    try:
+        return AutoTokenizer.from_pretrained(name)
+    except OSError:
+        return None