custom_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/OpenAssistant/oasst1
import copy
import datasets
import itertools

B_INST, E_INST = "[INST]", "[/INST]"
EOT_ID = 128009  # <|eot_id|>
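
# Mask every occurrence of the token sequence `target` inside `seq` with -100
# (in place) so those positions are ignored by the cross-entropy loss.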
def mask_target(target, seq):
    # `+ 1` so a match ending exactly at the end of `seq` is also found.
    for i in range(len(seq) - len(target) + 1):
        if seq[i:i+len(target)] == target:
            seq[i:i+len(target)] = [-100] * len(target)
    return seq
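
# Turn a list of {"role", "content"} messages into input_ids / labels / attention_mask.
# Prompt tokens (system and user turns) get the label -100 so the loss is only
# computed on the assistant replies.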
def tokenize_dialog(dialog, tokenizer):
    if tokenizer.vocab_size >= 128000:
        # Llama 3 style tokenizer: rely on the chat template and mask message by message.
        dialog_tokens = tokenizer.apply_chat_template(dialog)
        eot_indices = [i for i, n in enumerate(dialog_tokens) if n == EOT_ID]
        labels = copy.copy(dialog_tokens)
        # Determine the token ids of the "system" and "user" role names.
        system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
        last_idx = 0
        for n, idx in enumerate(eot_indices):
            # Each message starts <bos or previous eot> <|start_header_id|> <role> <|end_header_id|>,
            # so the third token of the segment is its role name.
            role_token = labels[last_idx:idx+1][2]
            if role_token in system_or_user:
                # Set labels to -100 for system and user tokens to ignore in loss function
                labels[last_idx:idx+1] = [-100] * (idx - last_idx + 1)
            last_idx = idx
        # Also mask the assistant header tokens that precede each assistant reply.
        mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
        dialog_tokens = [dialog_tokens]
        labels_tokens = [labels]
    else:
        # Llama 2 style tokenizer: build the [INST] ... [/INST] prompt format by hand.
        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
        # Add labels, convert prompt tokens to -100 in order to ignore them in the loss function
        labels_tokens = [len(c) * [-100] if i % 2 == 0 else c for i, c in enumerate(dialog_tokens)]

    combined_tokens = {
        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
        "labels": list(itertools.chain(*(t for t in labels_tokens))),
    }

    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
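
# Load OpenAssistant/oasst1 and flatten its message tree into individual dialogs:
# every root-to-leaf path becomes one conversation with alternating user/assistant
# turns, which is then tokenized with tokenize_dialog above.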
def get_custom_dataset(dataset_config, tokenizer, split):
    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)

    # Keep only the columns needed to rebuild the conversation tree.
    dataset = dataset.map(lambda sample: {
        "message_id": sample["message_id"],
        "parent_id": sample["parent_id"],
        "text": sample["text"],
        },
        batched=True,
        remove_columns=list(dataset.features),)

    nodes = {}
    messages = {}
    root_ids = []

    # Build a parent_id -> [child message_ids] map and collect the root messages.
    for data in dataset:
        if data["parent_id"]:
            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
        else:
            root_ids.append(data["message_id"])
        messages[data["message_id"]] = data["text"]

    # Depth-first walk: return every thread obtained by extending `thread` through
    # `current_id` down to each leaf of the tree.
    def follow(thread, current_id):
        thread = copy.copy(thread) + [messages[current_id]]
        if current_id in nodes:
            new_threads = []
            for next_id in nodes[current_id]:
                new_threads += follow(thread, next_id)
            return new_threads
        else:
            return [thread]

    def get_threads_from_root(root_id):
        all_threads = []
        thread = [messages[root_id]]
        for cid in nodes[root_id]:
            all_threads += follow(thread, cid)
        return all_threads

    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
    # Flatten the per-root list of threads into one thread per row.
    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)

    # Assign alternating user/assistant roles along each thread.
    def to_dialog(thread):
        dialog = []
        for i, content in enumerate(thread):
            dialog.append({
                "role": "user" if i % 2 == 0 else "assistant",
                "content": content,
            })
        return {"dialog": dialog}

    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))

    return dataset
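
# --- Illustrative usage sketch (added for illustration; not part of the original recipe) ---
# Minimal smoke test. It assumes access to a Llama 3 tokenizer on the Hugging Face
# Hub (the meta-llama repos are gated); any tokenizer that emits the
# <|start_header_id|>/<|eot_id|> chat template behaves the same way here.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    ds = get_custom_dataset(None, tokenizer, split="validation")  # dataset_config is unused
    sample = ds[0]
    print(f"{len(ds)} threads, first thread has {len(sample['input_ids'])} tokens")
    # Decode only the positions that contribute to the loss (the assistant replies).
    print(tokenizer.decode([t for t, l in zip(sample["input_ids"], sample["labels"]) if l != -100]))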