# raft_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

from datasets import load_dataset

# Llama 2 style instruction delimiters used to wrap the full prompt.
B_INST, E_INST = "[INST]", "[/INST]"

def raft_tokenize(q_a_pair, tokenizer):
    # The last line of the instruction is the question.
    question = q_a_pair["instruction"].split('\n')[-1]
    # All lines before the last one are the retrieved context documents.
    documents = q_a_pair["instruction"].split('\n')[:-1]
    # The chain-of-thought answer is the training target.
    answer = q_a_pair["output"]
    system_prompt = "You are a helpful question answerer who can provide an answer given a question and relevant context."
    user_prompt = """
        Question: {question}\nContext: {context}\n
        Answer this question using the information given in the context above. Here are things to pay attention to:
        - First provide step-by-step reasoning on how to answer the question.
        - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This means that anything outside of ##begin_quote## and ##end_quote## is not directly copied from the context.
        - End your response with the final answer in the form <ANSWER>: $answer; the answer should be succinct.
        You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=str(documents))
    final_prompt = system_prompt + '\n' + user_prompt
    prompt_tokens = tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {final_prompt.strip()} {E_INST}", add_special_tokens=False)
    answer_tokens = tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False)
    # Mask prompt tokens with -100 so the loss is computed only on the answer tokens.
    sample = {
        "input_ids": prompt_tokens + answer_tokens,
        "attention_mask": [1] * (len(prompt_tokens) + len(answer_tokens)),
        "labels": [-100] * len(prompt_tokens) + answer_tokens,
    }
    return sample
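

# Sketch of what raft_tokenize produces, assuming a Hugging Face tokenizer that
# defines bos_token and eos_token; the checkpoint name below is a placeholder,
# not something this module depends on:
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#     sample = raft_tokenize(
#         {"instruction": "some context line\nWhat is RAFT?",
#          "output": "##Reason: ... <ANSWER>: Retrieval Augmented Fine Tuning"},
#         tok,
#     )
#     # All three lists have the same length; the prompt span of "labels" is
#     # -100, so only answer tokens contribute to the loss.
#     assert len(sample["input_ids"]) == len(sample["labels"])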


def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
    # load_dataset returns a DatasetDict whose single "train" split holds all the data.
    dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
    dataset = dataset_dict['train']
    # Deterministic train/test split controlled by split_ratio.
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)

    # Keep the instruction as-is and expose the chain-of-thought answer under "output".
    dataset = dataset[split].map(
        lambda sample: {
            "instruction": sample["instruction"],
            "output": sample["cot_answer"],
        },
        batched=True,
    )
    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
    return dataset
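

if __name__ == "__main__":
    # Minimal smoke test; a sketch, assuming a local JSON-lines file produced by
    # the RAFT data-generation step with "instruction" and "cot_answer" fields.
    # Both the data path and the tokenizer checkpoint below are placeholders,
    # not values this module defines.
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    dataset_config = SimpleNamespace(data_path="raft_train.jsonl")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    train_split = get_custom_dataset(dataset_config, tokenizer, split="train")
    print(train_split)
    print(f"first sample has {len(train_split[0]['input_ids'])} tokens")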