Browse Source

Add preprocessor to patch PromptGuard scores for inserted characters (#636)

Cyrus Nikolaidis 7 months ago
parent
commit
3a99a54582
1 changed files with 54 additions and 12 deletions
  1. 54 12
      recipes/responsible_ai/prompt_guard/inference.py

+ 54 - 12
recipes/responsible_ai/prompt_guard/inference.py

@@ -31,7 +31,45 @@ def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
     return model, tokenizer
 
 
-def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
+def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
+    """
+    Preprocess the text by removing spaces that break apart larger tokens.
+    This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
+    to allow the string to be classified as benign.
+
+    Args:
+        text (str): The input text to preprocess.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
+
+    Returns:
+        str: The preprocessed text.
+    """
+
+    try:
+        cleaned_text = ''
+        index_map = []
+        for i, char in enumerate(text):
+            if not char.isspace():
+                cleaned_text += char
+                index_map.append(i)
+        tokens = tokenizer.tokenize(cleaned_text)
+        result = []
+        last_end = 0
+        for token in tokens:
+            token_str = tokenizer.convert_tokens_to_string([token])
+            start = cleaned_text.index(token_str, last_end)
+            end = start + len(token_str)
+            original_start = index_map[start]
+            if original_start > 0 and text[original_start - 1].isspace():
+                result.append(' ')
+            result.append(token_str)
+            last_end = end
+        return ''.join(result)
+    except Exception:
+        return text
+
+
+def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the model on the given text with temperature-adjusted softmax.
     Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
@@ -44,6 +82,8 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     Returns:
         torch.Tensor: The probability of each class adjusted by the temperature.
     """
+    if preprocess:
+        text = preprocess_text_for_promptguard(text, tokenizer)
     # Encode the text
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
@@ -57,7 +97,7 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     return probabilities
 
 
-def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
     Appropriate for filtering dialogue between a user and an LLM.
@@ -70,11 +110,11 @@ def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
     Returns:
         float: The probability of the text containing malicious content.
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return probabilities[0, 2].item()
 
 
-def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
     Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
@@ -87,11 +127,11 @@ def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device
     Returns:
         float: The combined probability of the text containing malicious or embedded instructions.
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return (probabilities[0, 1] + probabilities[0, 2]).item()
 
 
-def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
+def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
     """
     Process a batch of texts and return their class probabilities.
     Args:
@@ -104,6 +144,8 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     Returns:
         torch.Tensor: A tensor containing the class probabilities for each text in the batch.
     """
+    if preprocess:
+        texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
     with torch.no_grad():
@@ -113,7 +155,7 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     return probabilities
 
 
-def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
+def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
     Args:
@@ -138,7 +180,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     for i in range(0, len(all_chunks), max_batch_size):
         batch_chunks = all_chunks[i:i+max_batch_size]
         batch_indices = text_indices[i:i+max_batch_size]
-        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
+        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
         scores = probabilities[:, score_indices].sum(dim=1).tolist()
         
         for idx, score in zip(batch_indices, scores):
@@ -146,7 +188,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     return all_scores
 
 
-def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute jailbreak scores for a list of texts.
     Args:
@@ -160,10 +202,10 @@ def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, dev
     Returns:
         list[float]: A list of jailbreak scores for each text.
     """
-    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
 
 
-def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute indirect injection scores for a list of texts.
     Args:
@@ -177,4 +219,4 @@ def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature
     Returns:
         list[float]: A list of indirect injection scores for each text.
     """
-    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)