@@ -8,16 +8,15 @@ import datasets
 
from unittest.mock import patch
 
-
@patch('builtins.input', return_value="N")
def load_samsum(split, _):
try:
ds = datasets.load_dataset("knkarthick/samsum", split=split)
except ValueError as e:
if "trust_remote_code" in str(e):
- raise ValueError("Loading knkarthick/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
+ raise ValueError("Loading knkarthick/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
else:
- raise e
+ raise e
return ds
 
 
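The ValueError in the hunk above directs users to the HF_DATASETS_TRUST_REMOTE_CODE environment variable. Assuming a datasets release that reads that variable, opting in (only after reviewing the dataset script in the knkarthick/samsum repo) might look like the following sketch, which is not part of the diff itself:

import os

# Set before `datasets` is imported, since the library typically reads the variable at import time;
# exporting HF_DATASETS_TRUST_REMOTE_CODE=True in the shell before launching Python also works.
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "True"

import datasets
ds = datasets.load_dataset("knkarthick/samsum", split="train")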
@@ -34,24 +33,20 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
"summary": sample["summary"],
}
 
- dataset = dataset.map(apply_prompt_template,
- remove_columns=list(dataset.features))
+ dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
 
def tokenize_add_label(sample):
- prompt = tokenizer.encode(
- tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
- summary = tokenizer.encode(
- sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
+ prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+ summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
 
sample = {
"input_ids": prompt + summary,
- "attention_mask": [1] * (len(prompt) + len(summary)),
+ "attention_mask" : [1] * (len(prompt) + len(summary)),
"labels": [-100] * len(prompt) + summary,
- }
+ }
 
return sample
 
- dataset = dataset.map(tokenize_add_label,
- remove_columns=list(dataset.features))
+ dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
 
return dataset
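For context, a minimal sketch of how the patched helpers might be exercised; the import path, the tokenizer checkpoint, and the None placeholder for dataset_config are illustrative assumptions, not something this diff prescribes:

# Assumed import path; the diff does not show the module name.
from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum
from transformers import AutoTokenizer

# Any tokenizer that defines bos_token and eos_token should work; this checkpoint is only an example.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# None is a placeholder for dataset_config; the surrounding repo presumably supplies a real config object.
train_ds = get_preprocessed_samsum(None, tokenizer, "train")

sample = train_ds[0]
print(len(sample["input_ids"]))  # prompt + summary token ids
print(sample["labels"][:10])     # prompt positions are masked with -100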