samsum_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/samsum

import copy

import datasets

from unittest.mock import patch


# Patch builtins.input so the interactive trust_remote_code prompt raised by the
# samsum loading script is auto-answered with "N" instead of blocking; the
# resulting ValueError is re-raised below with clearer instructions.
@patch('builtins.input', return_value="N")
def load_samsum(split, _):
    try:
        ds = datasets.load_dataset("Samsung/samsum", split=split)
    except ValueError as e:
        if "trust_remote_code" in str(e):
            raise ValueError("Loading Samsung/samsum requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set HF_DATASETS_TRUST_REMOTE_CODE env variable to True.") from e
        else:
            raise e
    return ds
def get_preprocessed_samsum(dataset_config, tokenizer, split):
    dataset = load_samsum(split)

    # The doubled braces survive the f-string as a literal {dialog} placeholder,
    # which is filled in via str.format below.
    prompt = (
        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
    )

    def apply_prompt_template(sample):
        return {
            "prompt": prompt.format(dialog=sample["dialogue"]),
            "summary": sample["summary"],
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_add_label(sample):
        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
        summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)

        # Mask the prompt tokens with -100 so the loss is computed only on the summary.
        sample = {
            "input_ids": prompt + summary,
            "attention_mask": [1] * (len(prompt) + len(summary)),
            "labels": [-100] * len(prompt) + summary,
        }

        return sample

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

    return dataset
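

# Usage sketch: a minimal example of how get_preprocessed_samsum might be called.
# It assumes the `transformers` package is installed, that the checkpoint name
# below is merely an illustration (swap in whatever Llama tokenizer you use),
# and that HF_DATASETS_TRUST_REMOTE_CODE=True is exported so the samsum loading
# script is allowed to run.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
    train_ds = get_preprocessed_samsum(dataset_config=None, tokenizer=tokenizer, split="train")

    # Each record now holds input_ids, attention_mask, and labels, with the
    # prompt tokens masked to -100 so only the summary contributes to the loss.
    print(train_ds[0]["input_ids"][:10])
    print(train_ds[0]["labels"][:10])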