raft_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy
import itertools

from datasets import load_dataset

# Llama 2 style instruction markers (unused by the Llama 3 chat-template path below).
B_INST, E_INST = "[INST]", "[/INST]"
# Check whether the system prompt or user prompt header token sequence
# appears anywhere in the current token list.
def check_header(targets, seq):
    # Slide a 3-token window over seq; the range stops at len(seq) - 2 so the
    # final window is still 3 tokens long.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False
# Replace every occurrence of the 3-token target sequence in seq with the
# ignore index -100, so those positions are excluded from the loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
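# A minimal usage sketch of the two helpers above. The header IDs are the
# real Llama 3 header tokens; the surrounding values (1, 9, 42) are made up
# for illustration:
#
#   >>> check_header([[128006, 9125, 128007]], [1, 128006, 9125, 128007, 9])
#   True
#   >>> replace_target([128006, 78191, 128007], [128006, 78191, 128007, 42])
#   [-100, -100, -100, 42]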
def tokenize_dialog(dialog, tokenizer):
    # If the vocab size is at least 128,000, the tokenizer comes from a Llama 3
    # family model, so use its chat template to generate the tokens.
    if tokenizer.vocab_size >= 128000:
        dialog_tokens = tokenizer.apply_chat_template(dialog)
        # 128009 is <|eot_id|>, which terminates every turn.
        eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
        labels = copy.copy(dialog_tokens)
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, indicating that this segment should be masked.
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                # Assistant turn: keep its labels and start the next segment
                # after this <|eot_id|>.
                last_idx = idx + 1
        # Lastly, mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        dialog_tokens = [dialog_tokens]
        labels_tokens = [labels]
    else:
        raise ValueError(
            "This raft_dataset only supports Llama 3 family models; "
            "please make sure the tokenizer comes from a Llama 3 family model."
        )
    combined_tokens = {
        "input_ids": list(itertools.chain(*dialog_tokens)),
        "labels": list(itertools.chain(*labels_tokens)),
    }
    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
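# Masking sketch for one system + user + assistant exchange (schematic
# positions, not real token IDs):
#
#   input_ids: [sys_hdr] sys... [eot] [usr_hdr] usr... [eot] [asst_hdr] ans... [eot]
#   labels:    -100 .......................................  -100 x 3   ans... [eot]
#
# Only the assistant's answer tokens (plus the closing <|eot_id|>) are
# supervised; every prompt token carries the ignore index -100.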
def raft_tokenize(q_a_pair, tokenizer):
    end_tag = "</DOCUMENT>\n"
    # Everything after the last end_tag in the instruction is the question.
    index = q_a_pair["instruction"].rindex(end_tag) + len(end_tag)
    question = q_a_pair["instruction"][index:]
    # Everything up to and including the last end_tag is the context.
    documents = q_a_pair["instruction"][:index]
    # The output field is the label.
    answer = q_a_pair["output"]
    system_prompt = "You are a helpful chatbot who can provide an answer to every question from the user given a relevant context."
    user_prompt = """
        Question: {question}\nContext: {context}\n
        Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
        - The context contains many documents; each document starts with <DOCUMENT> and ends with </DOCUMENT>.
        - First provide step-by-step reasoning on how to answer the question.
        - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy pasted from the context.
        - End your response with the final answer in the form <ANSWER>: $answer; the answer should be less than 60 words.
        You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=documents)
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer},
    ]
    return tokenize_dialog(chat, tokenizer)
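# A sketch of the record shape raft_tokenize() expects. The field values are
# invented for illustration; only the <DOCUMENT> framing and the
# "instruction"/"output" keys matter:
#
#   q_a_pair = {
#       "instruction": "<DOCUMENT>Oak trees can live 1,000 years.</DOCUMENT>\n"
#                      "How long can oak trees live?",
#       "output": "##begin_quote##Oak trees can live 1,000 years.##end_quote##"
#                 " <ANSWER>: About 1,000 years.",
#   }
#   features = raft_tokenize(q_a_pair, tokenizer)
#   # -> {"input_ids": [...], "labels": [...], "attention_mask": [...]}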
def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict with all of the data under the "train" split.
    dataset_dict = load_dataset("json", data_files=dataset_config.data_path)
    dataset = dataset_dict["train"]
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)
    dataset = dataset[split].map(
        lambda sample: {
            "instruction": sample["instruction"],
            "output": sample["cot_answer"],
        },
        batched=True,
    )
    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
    return dataset
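

if __name__ == "__main__":
    # A minimal smoke test, assuming access to a Llama 3 family tokenizer.
    # The checkpoint name below is an assumption -- substitute any Llama 3
    # tokenizer you are licensed to download.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    sample = {
        "instruction": "<DOCUMENT>Oak trees can live 1,000 years.</DOCUMENT>\n"
                       "How long can oak trees live?",
        "output": "##begin_quote##Oak trees can live 1,000 years.##end_quote##"
                  " <ANSWER>: About 1,000 years.",
    }
    features = raft_tokenize(sample, tokenizer)
    # Labels and inputs must stay aligned, and at least the answer tokens
    # should remain unmasked.
    assert len(features["input_ids"]) == len(features["labels"])
    supervised = sum(label != -100 for label in features["labels"])
    print(f"{supervised} of {len(features['labels'])} tokens are supervised")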