|
@@ -31,7 +31,45 @@ def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
-def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
|
|
|
+def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
|
|
|
+ """
|
|
|
+ Preprocess the text by removing spaces that break apart larger tokens.
|
|
|
+ This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
|
|
|
+ to allow the string to be classified as benign.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text (str): The input text to preprocess.
|
|
|
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: The preprocessed text.
|
|
|
+ """
|
|
|
+
|
|
|
+ try:
|
|
|
+ cleaned_text = ''
|
|
|
+ index_map = []
|
|
|
+ for i, char in enumerate(text):
|
|
|
+ if not char.isspace():
|
|
|
+ cleaned_text += char
|
|
|
+ index_map.append(i)
|
|
|
+ tokens = tokenizer.tokenize(cleaned_text)
|
|
|
+ result = []
|
|
|
+ last_end = 0
|
|
|
+ for token in tokens:
|
|
|
+ token_str = tokenizer.convert_tokens_to_string([token])
|
|
|
+ start = cleaned_text.index(token_str, last_end)
|
|
|
+ end = start + len(token_str)
|
|
|
+ original_start = index_map[start]
|
|
|
+ if original_start > 0 and text[original_start - 1].isspace():
|
|
|
+ result.append(' ')
|
|
|
+ result.append(token_str)
|
|
|
+ last_end = end
|
|
|
+ return ''.join(result)
|
|
|
+ except Exception:
|
|
|
+ return text
|
|
|
+
|
|
|
+
|
|
|
+def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
|
|
|
"""
|
|
|
Evaluate the model on the given text with temperature-adjusted softmax.
|
|
|
Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
|
|
@@ -44,6 +82,8 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
|
|
|
Returns:
|
|
|
torch.Tensor: The probability of each class adjusted by the temperature.
|
|
|
"""
|
|
|
+ if preprocess:
|
|
|
+ text = preprocess_text_for_promptguard(text, tokenizer)
|
|
|
# Encode the text
|
|
|
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
|
|
inputs = inputs.to(device)
|
|
@@ -57,7 +97,7 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
|
|
|
return probabilities
|
|
|
|
|
|
|
|
|
-def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
|
|
|
+def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
|
|
|
"""
|
|
|
Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
|
|
|
Appropriate for filtering dialogue between a user and an LLM.
|
|
@@ -70,11 +110,11 @@ def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
|
|
|
Returns:
|
|
|
float: The probability of the text containing malicious content.
|
|
|
"""
|
|
|
- probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
|
|
|
+ probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
|
|
|
return probabilities[0, 2].item()
|
|
|
|
|
|
|
|
|
-def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
|
|
|
+def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
|
|
|
"""
|
|
|
Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
|
|
|
Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
|
|
@@ -87,11 +127,11 @@ def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device
|
|
|
Returns:
|
|
|
float: The combined probability of the text containing malicious or embedded instructions.
|
|
|
"""
|
|
|
- probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
|
|
|
+ probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
|
|
|
return (probabilities[0, 1] + probabilities[0, 2]).item()
|
|
|
|
|
|
|
|
|
-def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
|
|
|
+def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
|
|
|
"""
|
|
|
Process a batch of texts and return their class probabilities.
|
|
|
Args:
|
|
@@ -104,6 +144,8 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
|
|
|
Returns:
|
|
|
torch.Tensor: A tensor containing the class probabilities for each text in the batch.
|
|
|
"""
|
|
|
+ if preprocess:
|
|
|
+ texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
|
|
|
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
|
|
inputs = inputs.to(device)
|
|
|
with torch.no_grad():
|
|
@@ -113,7 +155,7 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
|
|
|
return probabilities
|
|
|
|
|
|
|
|
|
-def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
|
|
|
+def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
|
|
|
"""
|
|
|
Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
|
|
|
Args:
|
|
@@ -138,7 +180,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
|
|
|
for i in range(0, len(all_chunks), max_batch_size):
|
|
|
batch_chunks = all_chunks[i:i+max_batch_size]
|
|
|
batch_indices = text_indices[i:i+max_batch_size]
|
|
|
- probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
|
|
|
+ probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
|
|
|
scores = probabilities[:, score_indices].sum(dim=1).tolist()
|
|
|
|
|
|
for idx, score in zip(batch_indices, scores):
|
|
@@ -146,7 +188,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
|
|
|
return all_scores
|
|
|
|
|
|
|
|
|
-def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
|
|
|
+def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
|
|
|
"""
|
|
|
Compute jailbreak scores for a list of texts.
|
|
|
Args:
|
|
@@ -160,10 +202,10 @@ def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, dev
|
|
|
Returns:
|
|
|
list[float]: A list of jailbreak scores for each text.
|
|
|
"""
|
|
|
- return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
|
|
|
+ return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
|
|
|
|
|
|
|
|
|
-def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
|
|
|
+def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
|
|
|
"""
|
|
|
Compute indirect injection scores for a list of texts.
|
|
|
Args:
|
|
@@ -177,4 +219,4 @@ def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature
|
|
|
Returns:
|
|
|
list[float]: A list of indirect injection scores for each text.
|
|
|
"""
|
|
|
- return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)
|
|
|
+ return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)
|