@@ -24,13 +24,14 @@ def tokenize_dialog(dialog, tokenizer):
         labels = copy.copy(dialog_tokens)
         #determine token for system and user
         system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
-        last_idx = 0
+        labels[0] = -100 # bos token
+        last_idx = 1
         for n, idx in enumerate(eot_indices):
-            role_token = labels[last_idx:idx+1][2]
+            role_token = labels[last_idx+1]
             if role_token in system_or_user:
                 # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-            last_idx = idx
+            last_idx = idx + 1
         mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)

         dialog_tokens = [dialog_tokens]
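For reference, the masking logic introduced in the hunk above can be exercised with a small standalone sketch. The token ids below (BOS, HEADER, EOT and the role ids) are hypothetical stand-ins for real tokenizer output, chosen only to make the example runnable; the trailing mask_target call on the assistant header is omitted.

```python
import copy

# Illustration-only token ids; a real tokenizer assigns different values.
BOS, HEADER, EOT = 1, 7, 9
SYSTEM, USER, ASSISTANT = 2, 3, 4
system_or_user = (SYSTEM, USER)

# <bos>, then three turns, each laid out as: <header> <role> <content...> <eot>
dialog_tokens = [BOS,
                 HEADER, SYSTEM, 11, 12, EOT,
                 HEADER, USER, 13, 14, EOT,
                 HEADER, ASSISTANT, 15, 16, EOT]
eot_indices = [i for i, t in enumerate(dialog_tokens) if t == EOT]

labels = copy.copy(dialog_tokens)
labels[0] = -100                       # bos token carries no training target
last_idx = 1
for idx in eot_indices:
    role_token = labels[last_idx + 1]  # role id sits right after the header token
    if role_token in system_or_user:
        # Set labels to -100 so system and user turns are ignored by the loss
        labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
    last_idx = idx + 1

print(labels)
# -> [-100]*11 + [7, 4, 15, 16, 9]: only the assistant turn keeps its labels
```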