raft_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy
import itertools

import datasets
from datasets import Dataset, load_dataset, DatasetDict

B_INST, E_INST = "[INST]", "[/INST]"


def tokenize_dialog(dialog, tokenizer):
    # A vocab size of 128000 or above indicates a Llama 3 family tokenizer, so use its chat template to generate the tokens
    if tokenizer.vocab_size >= 128000:
        # apply_chat_template is called without add_generation_prompt, so no trailing assistant header is appended that would need stripping
        dialog_tokens = tokenizer.apply_chat_template(dialog)
        eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]  # 128009 is <|eot_id|>
        labels = copy.copy(dialog_tokens)
        last_idx = 0
        # Each message ends with exactly one <|eot_id|>, so pair messages with their
        # closing eot and mask every non-assistant turn with -100; the loss is then
        # computed only on the assistant answers
        for message, idx in zip(dialog, eot_indices):
            if message["role"] != "assistant":
                labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
            last_idx = idx + 1
        dialog_tokens = [dialog_tokens]
        labels_tokens = [labels]
    else:
        # Otherwise assume a Llama 2 family tokenizer and build the [INST] prompt format by hand,
        # flattening the per-message token lists so input_ids is a single flat list of ids
        prompt_tokens = list(itertools.chain.from_iterable(
            tokenizer.encode(
                f"{tokenizer.bos_token}{B_INST} {prompt['content'].strip()} {E_INST}",
                add_special_tokens=False,
            )
            for prompt in dialog[:2]
        ))
        answer = dialog[-1]
        answer_tokens = tokenizer.encode(
            f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False
        )
        # Set the prompt tokens' labels to -100 so they are ignored by the loss function
        sample = {
            "input_ids": prompt_tokens + answer_tokens,
            "attention_mask": [1] * (len(prompt_tokens) + len(answer_tokens)),
            "labels": [-100] * len(prompt_tokens) + answer_tokens,
        }
        return sample

    combined_tokens = {
        "input_ids": list(itertools.chain(*dialog_tokens)),
        "labels": list(itertools.chain(*labels_tokens)),
    }
    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
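
# Example (sketch): with a Llama 3 tokenizer and a dialog such as
#   [{"role": "system", "content": "..."},
#    {"role": "user", "content": "..."},
#    {"role": "assistant", "content": "..."}]
# tokenize_dialog returns equal-length "input_ids", "labels", and "attention_mask"
# lists, with every label up to and including the user's <|eot_id|> set to -100.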


def raft_tokenize(q_a_pair, tokenizer):
    # The last line of the instruction is the question;
    # all the lines before it are the context documents
    question = q_a_pair["instruction"].split('\n')[-1]
    documents = q_a_pair["instruction"].split('\n')[:-1]
    # The output field is the label
    answer = q_a_pair["output"]
    system_prompt = "You are a helpful question answerer who can provide an answer given a question and relevant context."
    user_prompt = """
Question: {question}\nContext: {context}\n
Answer this question using the information given in the context above. Here are things to pay attention to:
- First provide step-by-step reasoning on how to answer the question.
- In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This means that things outside of ##begin_quote## and ##end_quote## are not directly copied from the context.
- End your response with a final answer in the form <ANSWER>: $answer; the answer should be succinct.
You MUST begin your final answer with the tag "<ANSWER>:".
""".format(question=question, context=str(documents))
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer},
    ]
    return tokenize_dialog(chat, tokenizer)
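
# Example row (sketch of the expected RAFT data layout):
#   {"instruction": "<doc 1>\n<doc 2>\n<question>",
#    "output": "<chain-of-thought answer>"}
# i.e. newline-separated context documents with the question on the last line.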


def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
    # load_dataset returns a DatasetDict whose "train" split holds all of the data
    dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
    dataset = dataset_dict['train']
    # Shuffle and split into "train" and "test" subsets; `split` selects one of the two
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)
    dataset = dataset[split].map(
        lambda sample: {
            "instruction": sample["instruction"],
            "output": sample["cot_answer"],
        },
        batched=True,
    )
    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
    return dataset
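

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training recipe: it assumes a local
    # "raft.jsonl" file with "instruction" and "cot_answer" fields and a Llama 3
    # tokenizer you have download access to; both names are illustrative.
    from dataclasses import dataclass

    from transformers import AutoTokenizer

    @dataclass
    class RaftConfig:
        # Hypothetical stand-in for the recipe's dataset_config object
        data_path: str = "raft.jsonl"

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    train_set = get_custom_dataset(RaftConfig(), tokenizer, split="train")
    # Each row now carries input_ids / labels / attention_mask ready for fine-tuning
    print({k: len(v) for k, v in train_set[0].items() if isinstance(v, list)})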