Forráskód Böngészése

default to trusted_code

lessw2020 5 hónapja
szülő
commit
ce154f56ae
1 módosított fájl, 11 hozzáadás és 8 törlés
  1. 11 8
      src/llama_recipes/datasets/samsum_dataset.py

+ 11 - 8
src/llama_recipes/datasets/samsum_dataset.py

@@ -4,15 +4,14 @@
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
 import copy
+
 import datasets
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
-    dataset = datasets.load_dataset("samsum", split=split)
+    dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=True)
 
-    prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
-    )
+    prompt = f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
 
     def apply_prompt_template(sample):
         return {
@@ -23,14 +22,18 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
 
     def tokenize_add_label(sample):
-        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
-        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+        prompt = tokenizer.encode(
+            tokenizer.bos_token + sample["prompt"], add_special_tokens=False
+        )
+        summary = tokenizer.encode(
+            sample["summary"] + tokenizer.eos_token, add_special_tokens=False
+        )
 
         sample = {
             "input_ids": prompt + summary,
-            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "attention_mask": [1] * (len(prompt) + len(summary)),
             "labels": [-100] * len(prompt) + summary,
-            }
+        }
 
         return sample