# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
# For dataset details visit: https://huggingface.co/datasets/OpenAssistant/oasst1
 
import copy
import itertools

import datasets
 
B_INST, E_INST = "[INST]", "[/INST]"
 
def tokenize_dialog(dialog, tokenizer):
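    # `dialog` is a list of {"role", "content"} turns that strictly alternate
    # user (even indices) and assistant (odd indices); see to_dialog() below.
    # Wrap each user turn in [INST] ... [/INST] and prepend the BOS token.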
    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
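    # Append the EOS token to each assistant answer so the model learns to stop.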
    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
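    # Interleave the pairs back into conversation order: [p0, a0, p1, a1, ...].
    # Note that zip() silently drops a trailing user turn that has no answer.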
    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
    # Add labels: replace every prompt token with -100 so the loss function ignores it
    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
    combined_tokens = {
        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
        "labels": list(itertools.chain(*(t for t in labels_tokens))),
    }
 
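    # Every position holds a real token, so the attention mask is all ones.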
    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
def get_custom_dataset(dataset_config, tokenizer, split):
    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
 
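    # Each oasst1 row is a single message; message_id and parent_id link the
    # messages into conversation trees, so keep only the fields needed to rebuild them.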
    dataset = dataset.map(lambda sample: {
        "message_id": sample["message_id"],
        "parent_id": sample["parent_id"],
        "text": sample["text"],
        },
        batched=True,
        remove_columns=list(dataset.features),)
 
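    # Rebuild the trees: `nodes` maps a message id to its child ids, `messages`
    # maps a message id to its text, and `root_ids` collects messages without a
    # parent (conversation starts).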
    nodes = {}
    messages = {}
    root_ids = []
 
    for data in dataset:
        if data["parent_id"]:
            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
        else:
            root_ids.append(data["message_id"])
        messages[data["message_id"]] = data["text"]
 
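    # Recursively extend `thread` with the current message and return one
    # completed thread per leaf below current_id (depth-first walk).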
    def follow(thread, current_id):
        thread = copy.copy(thread) + [messages[current_id]]
        if current_id in nodes:
            new_threads = []
            for next_id in nodes[current_id]:
                new_threads += follow(thread, next_id)
            return new_threads
        else:
            return [thread]
 
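    # Collect every thread that starts at a given conversation root. Indexing
    # nodes[root_id] assumes each root has at least one reply; a childless root
    # would raise a KeyError.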
    def get_threads_from_root(root_id):
        all_threads = []
        thread = [messages[root_id]]
        for cid in nodes[root_id]:
            all_threads += follow(thread, cid)
        return all_threads
 
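    # Keep only root messages, expand each root into all of its threads, then
    # flatten with a batched map so that every row holds exactly one thread.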
    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
 
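    # Turn a thread (a list of message texts) into alternating chat turns:
    # even positions become user prompts, odd positions assistant answers.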
    def to_dialog(thread):
        dialog = []
        for i, content in enumerate(thread):
            dialog.append({
                "role": "user" if i % 2 == 0 else "assistant",
                "content": content,
            })
        return {"dialog": dialog}
 
    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))

    return dataset
 
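# Minimal usage sketch (illustration only, not part of the recipe). It assumes
# the `transformers` package is installed and that the tokenizer id below is a
# placeholder you have access to; in llama-recipes this module is normally
# loaded through the custom dataset config rather than run directly.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
    train_dataset = get_custom_dataset(None, tokenizer, "train")
    # Each example carries input_ids, labels, and attention_mask of equal length.
    print({k: len(v) for k, v in train_dataset[0].items()})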