@@ -9,19 +9,22 @@ import itertools
 
 
 B_INST, E_INST = "[INST]", "[/INST]"
+EOT_ID = 128009 #<|eot_id|>
 
 def tokenize_dialog(dialog, tokenizer):
     if tokenizer.vocab_size >= 128000:
         dialog_tokens = tokenizer.apply_chat_template(dialog)
-        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
-        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == EOT_ID]
         labels = copy.copy(dialog_tokens)
+        #determine token for system and user
+        system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
         last_idx = 0
         for n, idx in enumerate(eot_indices):
-            if n % 2 == 1:
-                last_idx = idx
-            else:
+            role_token = labels[last_idx:idx+1][2]
+            if role_token in system_or_user:
+                # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            last_idx = idx
 
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
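
For reviewers, a minimal, self-contained sketch of the new masking loop. The special-token ids (128000, 128006, 128007, 128009) are the published Llama 3 values; SYSTEM, USER, ASSISTANT, and the content ids are made-up stand-ins for what tokenizer.encode(...) would return, so this runs without loading the real tokenizer:

import copy

# Published Llama 3 special-token ids:
# <|begin_of_text|>, <|start_header_id|>, <|end_header_id|>, <|eot_id|>
BOT, SOH, EOH, EOT_ID = 128000, 128006, 128007, 128009

# Made-up stand-ins for tokenizer.encode("system")[-1] / "user" / "assistant";
# the real ids depend on the Llama 3 tokenizer.
SYSTEM, USER, ASSISTANT = 101, 102, 103

def make_turn(role, content):
    # <|start_header_id|>{role}<|end_header_id|> ... <|eot_id|>
    return [SOH, role, EOH] + content + [EOT_ID]

dialog_tokens = ([BOT]
                 + make_turn(SYSTEM, [1, 2])
                 + make_turn(USER, [3, 4, 5])
                 + make_turn(ASSISTANT, [6, 7]))

# The masking loop introduced by this PR:
eot_indices = [i for i, n in enumerate(dialog_tokens) if n == EOT_ID]
labels = copy.copy(dialog_tokens)
system_or_user = (SYSTEM, USER)
last_idx = 0
for n, idx in enumerate(eot_indices):
    role_token = labels[last_idx:idx+1][2]  # role id sits two tokens into each segment
    if role_token in system_or_user:
        labels[last_idx:idx+1] = [-100] * (idx - last_idx + 1)
    last_idx = idx

print(labels)
# -> system and user turns are all -100; only the assistant turn
#    (header, content, and trailing <|eot_id|>) keeps its ids.

Note that under this scheme the assistant turn keeps its header tokens and final <|eot_id|> in the labels, so the model is also trained to emit the end-of-turn token, unlike the old parity-based (n % 2) masking, which assumed a strict user/assistant alternation and mislabeled dialogs that start with a system prompt.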