
raft_dataset.py must be used with llama3 tokenizer
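Context for the change: tokenize_dialog masks prompt tokens by matching hard-coded Llama 3 special-token IDs (128006, 128007, 128009), which do not exist in the Llama 2 vocabulary, so the old Llama 2 branch is replaced with an explicit error. A minimal sketch of that assumption, using the Hugging Face transformers AutoTokenizer; the checkpoint name is illustrative:

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any Llama 3 family tokenizer exposes the same
    # special tokens.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    # Llama 3 reserves IDs at and above 128000 for special tokens; these are
    # the exact IDs the prompt-masking loop in raft_dataset.py matches on:
    assert tokenizer.convert_tokens_to_ids("<|start_header_id|>") == 128006
    assert tokenizer.convert_tokens_to_ids("<|end_header_id|>") == 128007
    assert tokenizer.convert_tokens_to_ids("<|eot_id|>") == 128009

    # A Llama 2 (SentencePiece) tokenizer has none of these tokens, so the
    # masking logic would silently mask nothing, hence the new raise.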

Kai Wu, 1 year ago
commit 839492714c
1 changed file with 2 additions and 16 deletions

recipes/finetuning/datasets/raft_dataset.py  (+2 / -16)

@@ -3,8 +3,7 @@
 
 
 import copy
-import datasets
-from datasets import Dataset, load_dataset, DatasetDict
+from datasets import load_dataset
 import itertools
 
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -26,8 +25,6 @@ def tokenize_dialog(dialog, tokenizer):
         eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
         labels = copy.copy(dialog_tokens)
         last_idx = 0
-        token_length = len(dialog_tokens)
-        last_idx = 0
         # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
         # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
         prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
@@ -44,18 +41,7 @@ def tokenize_dialog(dialog, tokenizer):
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
     else:
-        # Otherwise, use the original tokenizer to generate the tokens as it is from Llama 2 family models
-        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[:2]]
-        answer = dialog[-1]
-        answer_tokens = tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False)
-
-        #Add labels, convert prompt token to -100 in order to ignore in loss function
-        sample = {
-            "input_ids": prompt_tokens + answer_tokens,
-            "attention_mask" : [1] * (len(prompt_tokens) + len(answer_tokens)),
-            "labels": [-100] * len(prompt_tokens) + answer_tokens,
-            }
-        return sample
+        raise Exception("This raft_dataset only supports Llama 3 family models, please make sure the tokenizer is from Llama 3 family models.")
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),