chatbot_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

import copy
import datasets
from datasets import Dataset, load_dataset, DatasetDict
import itertools

B_INST, E_INST = "[INST]", "[/INST]"


def tokenize_dialog(q_a_pair, tokenizer):
    # Wrap each question in the instruction delimiters and append the EOS token to each answer.
    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {question.strip()} {E_INST}", add_special_tokens=False) for question in q_a_pair["Question"]]
    answer_tokens = [tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in q_a_pair["Answer"]]

    # Interleave prompts and answers into a single dialog sequence.
    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))

    # Add labels: mask prompt tokens with -100 so they are ignored by the loss function.
    labels_tokens = [len(c) * [-100] if i % 2 == 0 else c for i, c in enumerate(dialog_tokens)]

    combined_tokens = {
        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
        "labels": list(itertools.chain(*(t for t in labels_tokens))),
    }

    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
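
# Illustration (not part of the original recipe): for a batch such as
#   {"Question": ["Hi"], "Answer": ["Hello!"]}
# tokenize_dialog returns input_ids for the prompt "{bos}[INST] Hi [/INST]" followed by
# the answer "Hello! {eos}", labels that are -100 over the prompt span and the answer
# token ids elsewhere, and an attention_mask of 1s matching input_ids in length.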

def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
    dataset = load_dataset('json', data_files=dataset_config.data_path)

    dataset = dataset['train'].train_test_split(test_size=1 - split_ratio, shuffle=True)

    dataset = dataset[split].map(lambda sample: {
        "Question": sample["Question"],
        "Answer": sample["Answer"],
        },
        batched=True,
    )

    dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))

    return dataset
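
if __name__ == "__main__":
    # Minimal sketch (not part of the original recipe): exercise tokenize_dialog with a
    # hypothetical toy tokenizer that only mimics the attributes and the encode() call the
    # function relies on, to show the shape of the features it produces.
    class _ToyTokenizer:
        bos_token = "<s>"
        eos_token = "</s>"

        def encode(self, text, add_special_tokens=False):
            # Map each whitespace-separated "word" to a small fake token id.
            return [abs(hash(word)) % 1000 for word in text.split()]

    batch = {"Question": ["What is fine-tuning?"], "Answer": ["Adapting a pretrained model."]}
    features = tokenize_dialog(batch, _ToyTokenizer())

    # input_ids, labels and attention_mask all have the same length;
    # labels are -100 wherever the token came from a prompt.
    assert len(features["input_ids"]) == len(features["labels"]) == len(features["attention_mask"])
    print(features["labels"])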