chatbot_dataset.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
  3. import copy
  4. import datasets
  5. from datasets import Dataset, load_dataset, DatasetDict
  6. import itertools
  7. B_INST, E_INST = "[INST]", "[/INST]"
  8. def tokenize_dialog(q_a_pair, tokenizer):
  9. question, answer = q_a_pair["Question"], q_a_pair["Answer"]
  10. prompt_tokens = tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(question).strip()} {E_INST}", add_special_tokens=False)
  11. answer_tokens = tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False)
  12. sample = {
  13. "input_ids": prompt_tokens + answer_tokens,
  14. "attention_mask" : [1] * (len(prompt_tokens) + len(answer_tokens)),
  15. "labels": [-100] * len(prompt_tokens) + answer_tokens,
  16. }
  17. return sample
  18. def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
  19. dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
  20. dataset = dataset_dict['train']
  21. dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)
  22. dataset = dataset[split].map(lambda sample: {
  23. "Question": sample["Question"],
  24. "Answer": sample["Answer"],
  25. },
  26. batched=True,
  27. )
  28. dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
  29. return dataset