Browse Source

Fix label masking to take the BOS token into account

Matthias Reso, cách đây 8 tháng
mục cha
commit
d9939cd0cf

+ 4 - 3
recipes/quickstart/finetuning/datasets/custom_dataset.py

@@ -24,13 +24,14 @@ def tokenize_dialog(dialog, tokenizer):
         labels = copy.copy(dialog_tokens)
         #determine token for system and user 
         system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
-        last_idx = 0
+        labels[0] = -100 # bos token
+        last_idx = 1
         for n, idx in enumerate(eot_indices):
-            role_token = labels[last_idx:idx+1][2]
+            role_token = labels[last_idx+1]
             if role_token in system_or_user:
                 # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-            last_idx = idx
+            last_idx = idx + 1
         mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
 
         dialog_tokens = [dialog_tokens]

+ 2 - 2
src/tests/datasets/test_custom_dataset.py

@@ -139,5 +139,5 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
         assert result["labels"].count(-100) == 11 + 12
     else:
         assert result["labels"][:38] == [-100] * 38
-        assert result["labels"][42:54] == [-100] * 12
-        assert result["labels"].count(-100) == 38 + 12
+        assert result["labels"][43:54] == [-100] * 11
+        assert result["labels"].count(-100) == 38 + 11