
Replaced non-existent dataset

varunfb committed 1 week ago
commit 9bba2863d9
1 changed file with 16 additions and 10 deletions

src/tests/datasets/test_samsum_datasets.py  +16 −10

@@ -1,32 +1,36 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import pytest
 from dataclasses import dataclass
 from functools import partial
 from unittest.mock import patch
+
+import pytest
 from datasets import load_dataset
 
+
 @dataclass
 class Config:
     model_type: str = "llama"
 
+
 try:
-    load_dataset("Samsung/samsum")
+    load_dataset("knkarthick/samsum")
     SAMSUM_UNAVAILABLE = False
 except ValueError:
     SAMSUM_UNAVAILABLE = True
 
+
 @pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable")
 @pytest.mark.skip_missing_tokenizer
-@patch('llama_cookbook.finetuning.train')
-@patch('llama_cookbook.finetuning.AutoTokenizer')
+@patch("llama_cookbook.finetuning.train")
+@patch("llama_cookbook.finetuning.AutoTokenizer")
 @patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
 @patch("llama_cookbook.finetuning.AutoProcessor")
 @patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
-@patch('llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_cookbook.finetuning.optim.AdamW')
-@patch('llama_cookbook.finetuning.StepLR')
+@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_cookbook.finetuning.optim.AdamW")
+@patch("llama_cookbook.finetuning.StepLR")
 def test_samsum_dataset(
     step_lr,
     optimizer,
@@ -39,11 +43,13 @@ def test_samsum_dataset(
     mocker,
     setup_tokenizer,
     llama_version,
-    ):
+):
     from llama_cookbook.finetuning import main
 
     setup_tokenizer(tokenizer)
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [
+        32000 if "Llama-2" in llama_version else 128256
+    ]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config()
 
@@ -55,7 +61,7 @@ def test_samsum_dataset(
         "use_peft": False,
         "dataset": "samsum_dataset",
         "batching_strategy": "padding",
-        }
+    }
 
     main(**kwargs)
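
For readers outside the diff, the guard this commit relies on reduces to the probe-then-skip pattern below. This is a minimal sketch, not the cookbook's actual test: only the knkarthick/samsum id and the ValueError handling come from the change above; test_needs_samsum and its body are illustrative.

import pytest
from datasets import load_dataset

# Probe the hub dataset once, at import/collection time. The commit
# above catches ValueError, which is what load_dataset raised here
# when the old "Samsung/samsum" id could no longer be resolved.
try:
    load_dataset("knkarthick/samsum")
    SAMSUM_UNAVAILABLE = False
except ValueError:
    SAMSUM_UNAVAILABLE = True


@pytest.mark.skipif(SAMSUM_UNAVAILABLE, reason="Samsum dataset is unavailable")
def test_needs_samsum():
    # Illustrative body: runs only when the probe above succeeded,
    # so an unreachable dataset produces a skip rather than a failure.
    ds = load_dataset("knkarthick/samsum")
    assert len(ds["train"]) > 0

Because the probe runs at module scope, the download cost is paid once during test collection, and pytest's skipif turns a missing dataset into a clean skip. Note also that the stacked @patch decorators in the test above inject mocks bottom-up, which is why step_lr (the bottom-most patch, StepLR) arrives as the first mock parameter.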