Matthias Reso пре 7 месеци
родитељ
комит
3c9f263c79

+ 7 - 0
recipes/quickstart/finetuning/datasets/custom_dataset.py

@@ -11,6 +11,12 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 EOT_ID = 128009 #<|eot_id|>
 
def mask_target(target, seq):
    """Mask every occurrence of the token sequence `target` inside `seq`.

    Each matched span is overwritten in place with -100 (the ignore index
    used by the loss function), so e.g. the assistant-header special tokens
    are excluded from training loss.

    Args:
        target: list of token ids to search for (non-empty).
        seq: list of token ids / labels to modify in place.

    Returns:
        The same `seq` list, for call-chaining convenience.
    """
    # +1 so a match that ends exactly at the end of seq is also found
    # (the original range(len(seq) - len(target)) missed that final position).
    for i in range(len(seq) - len(target) + 1):
        if seq[i:i + len(target)] == target:
            seq[i:i + len(target)] = [-100] * len(target)
    return seq
+
 def tokenize_dialog(dialog, tokenizer):
     if tokenizer.vocab_size >= 128000:
         dialog_tokens = tokenizer.apply_chat_template(dialog)
@@ -25,6 +31,7 @@ def tokenize_dialog(dialog, tokenizer):
                 # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
             last_idx = idx
+        mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
 
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]

+ 3 - 3
src/tests/datasets/test_custom_dataset.py

@@ -138,6 +138,6 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
         assert result["labels"][17:28] == [-100] * 11
         assert result["labels"].count(-100) == 11 + 12
     else:
-        assert result["labels"][:35] == [-100] * 35
-        assert result["labels"][42:51] == [-100] * 9
-        assert result["labels"].count(-100) == 35 + 9
+        assert result["labels"][:38] == [-100] * 38
+        assert result["labels"][42:54] == [-100] * 12
+        assert result["labels"].count(-100) == 38 + 12