| 
					
				 | 
			
			
				@@ -9,19 +9,22 @@ import itertools 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 B_INST, E_INST = "[INST]", "[/INST]" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+EOT_ID = 128009 #<|eot_id|> 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def tokenize_dialog(dialog, tokenizer): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if tokenizer.vocab_size >= 128000: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dialog_tokens = tokenizer.apply_chat_template(dialog) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == EOT_ID] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         labels = copy.copy(dialog_tokens) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #determine token for system and user  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         last_idx = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for n, idx in enumerate(eot_indices): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if n % 2 == 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                last_idx = idx 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            role_token = labels[last_idx:idx+1][2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if role_token in system_or_user: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # Set labels to -100 for system and user tokens to ignore in loss function 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            last_idx = idx 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dialog_tokens = [dialog_tokens] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         labels_tokens = [labels] 
			 |