|
@@ -4,15 +4,14 @@
|
|
|
# For dataset details visit: https://huggingface.co/datasets/samsum
|
|
|
|
|
|
import copy
|
|
|
+
|
|
|
import datasets
|
|
|
|
|
|
|
|
|
def get_preprocessed_samsum(dataset_config, tokenizer, split):
|
|
|
- dataset = datasets.load_dataset("samsum", split=split)
|
|
|
+ dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=True)
|
|
|
|
|
|
- prompt = (
|
|
|
- f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
|
|
|
- )
|
|
|
+    prompt = "Summarize this dialog:\n{dialog}\n---\nSummary:\n"
|
|
|
|
|
|
def apply_prompt_template(sample):
|
|
|
return {
|
|
@@ -23,14 +22,18 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
|
|
|
dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
|
|
|
|
|
|
def tokenize_add_label(sample):
|
|
|
- prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
|
|
|
- summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
|
|
|
+ prompt = tokenizer.encode(
|
|
|
+ tokenizer.bos_token + sample["prompt"], add_special_tokens=False
|
|
|
+ )
|
|
|
+ summary = tokenizer.encode(
|
|
|
+ sample["summary"] + tokenizer.eos_token, add_special_tokens=False
|
|
|
+ )
|
|
|
|
|
|
sample = {
|
|
|
"input_ids": prompt + summary,
|
|
|
- "attention_mask" : [1] * (len(prompt) + len(summary)),
|
|
|
+ "attention_mask": [1] * (len(prompt) + len(summary)),
|
|
|
"labels": [-100] * len(prompt) + summary,
|
|
|
- }
|
|
|
+ }
|
|
|
|
|
|
return sample
|
|
|
|