
add llama3 support for alpaca dataset

Kai Wu committed 7 months ago
Commit b5ae9aef56
1 file changed, 69 insertions(+), 24 deletions(-)

src/llama_recipes/datasets/alpaca_dataset.py  (+69, -24)

@@ -4,6 +4,7 @@
 # For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
 
 import copy
+import itertools
 import json
 
 import torch
@@ -23,11 +24,12 @@ PROMPT_DICT = {
     ),
 }
 
+
 class InstructionDataset(Dataset):
     def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         # Use 5% of the dataset for evaluation
-        eval_length = int(len(self.ann)/20)
+        eval_length = int(len(self.ann) / 20)
         if partition == "train":
             self.ann = self.ann[eval_length:]
         else:
@@ -41,30 +43,73 @@ class InstructionDataset(Dataset):
     def __getitem__(self, index):
         IGNORE_INDEX = -100  # The default setting in CrossEntropyLoss
 
-
         ann = self.ann[index]
         if ann.get("input", "") == "":
-            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
+            user_prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
+        else:
+            user_prompt = PROMPT_DICT["prompt_input"].format_map(ann)
+
+        # A vocab size of 128000 or more indicates a Llama 3 family tokenizer, so build the tokens with its chat template
+        if self.tokenizer.vocab_size >= 128000:
+            dialog = [
+                {"role": "user", "content": user_prompt},
+                {"role": "assistant", "content": ann["output"]},
+            ]
+            dialog_tokens = self.tokenizer.apply_chat_template(dialog)
+            eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
+            labels = copy.copy(dialog_tokens)
+            last_idx = 0
+            # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+            # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+            prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
+            for n, idx in enumerate(eot_indices):
+                current_seq = labels[last_idx : idx + 1]
+                if self.check_header(prompt_header_seqs, current_seq):
+                    # found prompt header, indicating that this seq should be masked
+                    labels[last_idx : idx + 1] = [IGNORE_INDEX] * (idx - last_idx + 1)
+                else:
+                    last_idx = idx + 1
+            # Lastly, mask every assistant header "<|start_header_id|>assistant<|end_header_id|>", which is tokenized to [128006, 78191, 128007]
+            assistant_header_seq = [128006, 78191, 128007]
+            labels = self.replace_target(assistant_header_seq, labels)
+            dialog_tokens = [dialog_tokens]
+            labels_tokens = [labels]
+            combined_tokens = {
+                "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+                "labels": list(itertools.chain(*(t for t in labels_tokens))),
+            }
+            return dict(
+                combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"])
+            )
         else:
-            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
-        example = prompt + ann["output"]
-        prompt = torch.tensor(
-            self.tokenizer.encode(prompt), dtype=torch.int64
-        )
-        example = self.tokenizer.encode(example)
-        example.append(self.tokenizer.eos_token_id)
-        example = torch.tensor(
-            example, dtype=torch.int64
-        )
-        labels = copy.deepcopy(example)
-        labels[: len(prompt)] = -1
-        example_mask = example.ge(0)
-        label_mask = labels.ge(0)
-        example[~example_mask] = 0
-        labels[~label_mask] = IGNORE_INDEX
+            # For Llama 2 fine-tuning, keep the original Alpaca prompt template
+            example = user_prompt + ann["output"]
+            prompt = torch.tensor(self.tokenizer.encode(user_prompt), dtype=torch.int64)
+            example = self.tokenizer.encode(example)
+            example.append(self.tokenizer.eos_token_id)
+            example = torch.tensor(example, dtype=torch.int64)
+            labels = copy.deepcopy(example)
+            labels[: len(prompt)] = -1
+            example_mask = example.ge(0)
+            label_mask = labels.ge(0)
+            example[~example_mask] = 0
+            labels[~label_mask] = IGNORE_INDEX
+
+            return {
+                "input_ids": example.tolist(),
+                "labels": labels.tolist(),
+                "attention_mask": example_mask.tolist(),
+            }
+
+    # Check whether a system or user prompt header token sequence occurs in the current token list
+    def check_header(self, targets, seq):
+        for i in range(len(seq) - 2):  # scan every 3-token window, including the final one
+            if seq[i : i + 3] in targets:
+                return True
+        return False
 
-        return {
-            "input_ids": example.tolist(),
-            "labels": labels.tolist(),
-            "attention_mask":example_mask.tolist(),
-        }
+    # Mask every assistant header sequence in the labels with -100 (IGNORE_INDEX)
+    def replace_target(self, target, seq):
+        for i in range(len(seq) - 2):  # scan every 3-token window, including the final one
+            if seq[i : i + 3] == target:
+                seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
+        return seq
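
For reference, the Llama 3 masking logic above can be exercised in isolation. The sketch below mirrors the per-turn loop and the two helper methods on a hand-built token list: the special IDs 128006/128007/128009 and the header sequences are the ones used in the diff, while the content tokens (10, 11, 20, 21) are placeholders rather than real tokenizer output.

# Standalone sketch of the Llama 3 label masking; the special token IDs
# (128006, 128007, 128009) and header sequences come from the diff above,
# while the content tokens (10, 11, 20, 21) are illustrative placeholders.
IGNORE_INDEX = -100

def check_header(targets, seq):
    # True if any 3-token header sequence from `targets` occurs in `seq`.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False

def replace_target(target, seq):
    # Overwrite every occurrence of the 3-token `target` with IGNORE_INDEX.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i : i + 3] = [IGNORE_INDEX] * 3
    return seq

# Toy dialog: one user turn (tokens 10, 11) and one assistant turn (tokens 20, 21),
# each closed by <|eot_id|> (128009).
dialog_tokens = [
    128006, 882, 128007, 10, 11, 128009,      # user header + prompt + <|eot_id|>
    128006, 78191, 128007, 20, 21, 128009,    # assistant header + answer + <|eot_id|>
]
labels = list(dialog_tokens)

prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
last_idx = 0
for idx in eot_indices:
    # Mask the whole turn if it starts with a system or user header.
    if check_header(prompt_header_seqs, labels[last_idx : idx + 1]):
        labels[last_idx : idx + 1] = [IGNORE_INDEX] * (idx - last_idx + 1)
    else:
        last_idx = idx + 1
labels = replace_target([128006, 78191, 128007], labels)

print(labels)
# [-100, -100, -100, -100, -100, -100, -100, -100, -100, 20, 21, 128009]

Only the assistant answer and its closing <|eot_id|> remain as training targets; user/system turns and all headers become IGNORE_INDEX, so the loss is computed on the response alone, while the attention_mask returned by __getitem__ still covers every input token.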