@@ -11,6 +11,12 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 EOT_ID = 128009 #<|eot_id|>
 
+def mask_target(target, seq):
+    for i in range(len(seq)-len(target)+1):
+        if seq[i:i+len(target)] == target:
+            seq[i:i+len(target)] = [-100] * len(target)
+    return seq
+
 def tokenize_dialog(dialog, tokenizer):
     if tokenizer.vocab_size >= 128000:
         dialog_tokens = tokenizer.apply_chat_template(dialog)
@@ -25,6 +31,7 @@ def tokenize_dialog(dialog, tokenizer):
                 # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
                 last_idx = idx
+        mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
 
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
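
To make the helper in the first hunk concrete, here is a minimal, self-contained
sketch of its behavior on a toy label sequence. The token ids are illustrative
stand-ins (128006/78191/128007 are assumed here to be what the Llama 3 tokenizer
produces for the assistant header), not verified tokenizer output:

    def mask_target(target, seq):
        # Overwrite every occurrence of `target` in `seq` with -100, in place.
        for i in range(len(seq)-len(target)+1):
            if seq[i:i+len(target)] == target:
                seq[i:i+len(target)] = [-100] * len(target)
        return seq

    header = [128006, 78191, 128007]             # assumed assistant-header ids
    labels = [42, 128006, 78191, 128007, 7, 11]  # arbitrary filler ids around it
    print(mask_target(header, labels))
    # -> [42, -100, -100, -100, 7, 11]

Since -100 is the default ignore_index of PyTorch's CrossEntropyLoss, the masked
positions contribute nothing to the fine-tuning loss.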
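
The call added in the second hunk is needed because the loop above only masks the
system and user turns between <|eot_id|> boundaries; the assistant header tokens sit
inside the assistant spans and would otherwise remain training targets. Encoding the
header string and sweeping the label list removes them as well. A sketch of obtaining
the header ids with a Hugging Face tokenizer (the model name is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    header_ids = tokenizer.encode(
        "<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False
    )
    mask_target(header_ids, labels)  # labels as built in tokenize_dialog above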