Browse Source

Add preprocessor to patch PromptGuard scores for inserted characters (#636)

Cyrus Nikolaidis 1 year ago
parent
commit
3a99a54582
1 changed files with 54 additions and 12 deletions
  1. 54 12
      recipes/responsible_ai/prompt_guard/inference.py

+ 54 - 12
recipes/responsible_ai/prompt_guard/inference.py

@@ -31,7 +31,45 @@ def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
     return model, tokenizer
     return model, tokenizer
 
 
 
 
-def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
+def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
+    """
+    Preprocess the text by removing spaces that break apart larger tokens.
+    This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
+    to allow the string to be classified as benign.
+
+    Args:
+        text (str): The input text to preprocess.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
+
+    Returns:
+        str: The preprocessed text.
+    """
+
+    try:
+        cleaned_text = ''
+        index_map = []
+        for i, char in enumerate(text):
+            if not char.isspace():
+                cleaned_text += char
+                index_map.append(i)
+        tokens = tokenizer.tokenize(cleaned_text)
+        result = []
+        last_end = 0
+        for token in tokens:
+            token_str = tokenizer.convert_tokens_to_string([token])
+            start = cleaned_text.index(token_str, last_end)
+            end = start + len(token_str)
+            original_start = index_map[start]
+            if original_start > 0 and text[original_start - 1].isspace():
+                result.append(' ')
+            result.append(token_str)
+            last_end = end
+        return ''.join(result)
+    except Exception:
+        return text
+
+
+def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     """
     Evaluate the model on the given text with temperature-adjusted softmax.
     Evaluate the model on the given text with temperature-adjusted softmax.
     Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
     Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
@@ -44,6 +82,8 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     Returns:
     Returns:
         torch.Tensor: The probability of each class adjusted by the temperature.
         torch.Tensor: The probability of each class adjusted by the temperature.
     """
     """
+    if preprocess:
+        text = preprocess_text_for_promptguard(text, tokenizer)
     # Encode the text
     # Encode the text
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
     inputs = inputs.to(device)
@@ -57,7 +97,7 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     return probabilities
     return probabilities
 
 
 
 
-def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     """
     Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
     Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
     Appropriate for filtering dialogue between a user and an LLM.
     Appropriate for filtering dialogue between a user and an LLM.
@@ -70,11 +110,11 @@ def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
     Returns:
     Returns:
         float: The probability of the text containing malicious content.
         float: The probability of the text containing malicious content.
     """
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return probabilities[0, 2].item()
     return probabilities[0, 2].item()
 
 
 
 
-def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     """
     Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
     Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
     Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
     Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
@@ -87,11 +127,11 @@ def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device
     Returns:
     Returns:
         float: The combined probability of the text containing malicious or embedded instructions.
         float: The combined probability of the text containing malicious or embedded instructions.
     """
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return (probabilities[0, 1] + probabilities[0, 2]).item()
     return (probabilities[0, 1] + probabilities[0, 2]).item()
 
 
 
 
-def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
+def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
     """
     """
     Process a batch of texts and return their class probabilities.
     Process a batch of texts and return their class probabilities.
     Args:
     Args:
@@ -104,6 +144,8 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     Returns:
     Returns:
         torch.Tensor: A tensor containing the class probabilities for each text in the batch.
         torch.Tensor: A tensor containing the class probabilities for each text in the batch.
     """
     """
+    if preprocess:
+        texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
     inputs = inputs.to(device)
     with torch.no_grad():
     with torch.no_grad():
@@ -113,7 +155,7 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     return probabilities
     return probabilities
 
 
 
 
-def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
+def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     """
     Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
     Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
     Args:
     Args:
@@ -138,7 +180,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     for i in range(0, len(all_chunks), max_batch_size):
     for i in range(0, len(all_chunks), max_batch_size):
         batch_chunks = all_chunks[i:i+max_batch_size]
         batch_chunks = all_chunks[i:i+max_batch_size]
         batch_indices = text_indices[i:i+max_batch_size]
         batch_indices = text_indices[i:i+max_batch_size]
-        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
+        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
         scores = probabilities[:, score_indices].sum(dim=1).tolist()
         scores = probabilities[:, score_indices].sum(dim=1).tolist()
         
         
         for idx, score in zip(batch_indices, scores):
         for idx, score in zip(batch_indices, scores):
@@ -146,7 +188,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     return all_scores
     return all_scores
 
 
 
 
-def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     """
     Compute jailbreak scores for a list of texts.
     Compute jailbreak scores for a list of texts.
     Args:
     Args:
@@ -160,10 +202,10 @@ def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, dev
     Returns:
     Returns:
         list[float]: A list of jailbreak scores for each text.
         list[float]: A list of jailbreak scores for each text.
     """
     """
-    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
 
 
 
 
-def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     """
     Compute indirect injection scores for a list of texts.
     Compute indirect injection scores for a list of texts.
     Args:
     Args:
@@ -177,4 +219,4 @@ def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature
     Returns:
     Returns:
         list[float]: A list of indirect injection scores for each text.
         list[float]: A list of indirect injection scores for each text.
     """
     """
-    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)