@@ -3,8 +3,7 @@
 
 
 import copy
-import datasets
-from datasets import Dataset, load_dataset, DatasetDict
+from datasets import load_dataset
 import itertools
 
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -26,8 +25,6 @@ def tokenize_dialog(dialog, tokenizer):
         eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
         labels = copy.copy(dialog_tokens)
         last_idx = 0
-        token_length = len(dialog_tokens)
-        last_idx = 0
         # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
         # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
         prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
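
Note on the hard-coded ids above: the header sequences and the 128009 end-of-turn
id can be verified directly against a Llama 3 tokenizer. The snippet below is a
minimal sanity check, not part of the patch; the checkpoint name is illustrative
and assumes a Llama 3 tokenizer is available through transformers.

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any Llama 3 family tokenizer should behave the same.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    # "<|start_header_id|>system<|end_header_id|>" -> [128006, 9125, 128007]
    print(tokenizer.encode("<|start_header_id|>system<|end_header_id|>", add_special_tokens=False))
    # "<|start_header_id|>user<|end_header_id|>" -> [128006, 882, 128007]
    print(tokenizer.encode("<|start_header_id|>user<|end_header_id|>", add_special_tokens=False))
    # "<|eot_id|>" (end of turn) -> [128009]
    print(tokenizer.encode("<|eot_id|>", add_special_tokens=False))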
@@ -44,18 +41,7 @@ def tokenize_dialog(dialog, tokenizer):
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
     else:
-        # Otherwise, use the original tokenizer to generate the tokens as it is from Llama 2 family models
-        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[:2]]
-        answer = dialog[-1]
-        answer_tokens = tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False)
-
-        #Add labels, convert prompt token to -100 in order to ignore in loss function
-        sample = {
-            "input_ids": prompt_tokens + answer_tokens,
-            "attention_mask" : [1] * (len(prompt_tokens) + len(answer_tokens)),
-            "labels": [-100] * len(prompt_tokens) + answer_tokens,
-        }
-        return sample
+        raise Exception("This raft_dataset only supports Llama 3 family models; please make sure the tokenizer is from a Llama 3 family model.")
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
|