瀏覽代碼

Remove trust_remote_code in favor of setting env variable

Matthias Reso 6 月之前
父節點
當前提交
8b01298aee

+ 0 - 1
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    trust_remote_code: bool = False
 
 
 @dataclass

+ 14 - 3
src/llama_recipes/datasets/samsum_dataset.py

@@ -6,11 +6,22 @@
 import copy
 import datasets
 
+from unittest.mock import patch
+
+@patch('builtins.input', return_value="N")
+def load_samsum(split, _):
+    try:
+        ds = datasets.load_dataset("Samsung/samsum", split=split)
+    except ValueError as e:
+        if "trust_remote_code" in str(e):
+          raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
+        else:
+          raise e
+    return ds
+
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
-    if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
-        raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
-    dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
+    dataset = load_samsum(split)
 
     prompt = (
         f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

+ 0 - 14
src/tests/datasets/test_samsum_datasets.py

@@ -5,17 +5,6 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
-EXPECTED_RESULTS = {
-    "meta-llama/Llama-2-7b-hf":{
-        "label": 8432,
-        "pos": 242,
-    },
-    "meta-llama/Meta-Llama-3.1-8B":{
-        "label": 2250,
-        "pos": 211,
-    },
-}
-
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -59,9 +48,6 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
-
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id
     assert batch["input_ids"][0][-1] == token.eos_token_id