remove back tick

nabidam committed 5 days ago
Parent commit: b1213aea21
1 file changed, 14 insertions(+), 9 deletions(-)

src/llama_cookbook/datasets/samsum_dataset.py  (+14 −9)

@@ -8,15 +8,16 @@ import datasets
 
 from unittest.mock import patch
 
+
 @patch('builtins.input', return_value="N")
 def load_samsum(split, _):
     try:
-        ds = datasets.load_dataset("knkarthick/samsum`", split=split)
+        ds = datasets.load_dataset("knkarthick/samsum", split=split)
     except ValueError as e:
         if "trust_remote_code" in str(e):
-          raise ValueError("Loading knkarthick/samsum` requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
+            raise ValueError("Loading knkarthick/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
         else:
-          raise e
+            raise e
     return ds
 
 
@@ -33,20 +34,24 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
             "summary": sample["summary"],
         }
 
-    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+    dataset = dataset.map(apply_prompt_template,
+                          remove_columns=list(dataset.features))
 
     def tokenize_add_label(sample):
-        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
-        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+        prompt = tokenizer.encode(
+            tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(
+            sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
 
         sample = {
             "input_ids": prompt + summary,
-            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "attention_mask": [1] * (len(prompt) + len(summary)),
             "labels": [-100] * len(prompt) + summary,
-            }
+        }
 
         return sample
 
-    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+    dataset = dataset.map(tokenize_add_label,
+                          remove_columns=list(dataset.features))
 
     return dataset
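
For context, a minimal usage sketch of the patched loader follows. It is not part of the commit: the module path, the tokenizer model id, and the None passed for dataset_config are assumptions made only for illustration, based on the file path and function signature shown in this diff. The HF_DATASETS_TRUST_REMOTE_CODE variable is only relevant if loading knkarthick/samsum raises the ValueError handled above, and the labels field masks the prompt tokens with -100 so the loss is computed on the summary only.

    # Usage sketch (assumptions: module path, model id, None for dataset_config)
    import os

    from transformers import AutoTokenizer

    from llama_cookbook.datasets.samsum_dataset import get_preprocessed_samsum

    # Only needed if loading knkarthick/samsum raises the ValueError handled above.
    os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "True"

    # Placeholder model id; any tokenizer with bos_token and eos_token set works here.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    # None stands in for dataset_config in this sketch; pass the project's config object if required.
    train_ds = get_preprocessed_samsum(None, tokenizer, split="train")

    sample = train_ds[0]
    # input_ids = prompt tokens + summary tokens; labels mask the prompt with -100,
    # so all three fields have the same length.
    assert len(sample["input_ids"]) == len(sample["labels"]) == len(sample["attention_mask"])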