samsum_dataset.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/samsum
  4. import copy
  5. import datasets
  6. def get_preprocessed_samsum(dataset_config, tokenizer, split):
  7. if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
  8. raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
  9. dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
  10. prompt = (
  11. f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
  12. )
  13. def apply_prompt_template(sample):
  14. return {
  15. "prompt": prompt.format(dialog=sample["dialogue"]),
  16. "summary": sample["summary"],
  17. }
  18. dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
  19. def tokenize_add_label(sample):
  20. prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
  21. summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
  22. sample = {
  23. "input_ids": prompt + summary,
  24. "attention_mask" : [1] * (len(prompt) + len(summary)),
  25. "labels": [-100] * len(prompt) + summary,
  26. }
  27. return sample
  28. dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
  29. return dataset