1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import torch
- from torch.nn.functional import softmax
- from transformers import (
- AutoModelForSequenceClassification,
- AutoTokenizer,
- )
- """
- Utilities for loading the PromptGuard model and evaluating text for jailbreaks and indirect injections.
- """
def load_model_and_tokenizer(model_name='meta-llama/PromptGuard'):
    """
    Load the PromptGuard classifier and its matching tokenizer.

    Both objects are loaded from the same source, so the tokenizer always
    matches the model's vocabulary.

    Args:
        model_name (str): Hugging Face Hub model id or path to a local
            directory containing the model files. Default is
            'meta-llama/PromptGuard'.

    Returns:
        tuple: ``(model, tokenizer)``, where ``model`` is the object returned
        by ``AutoModelForSequenceClassification.from_pretrained`` and
        ``tokenizer`` is the object returned by
        ``AutoTokenizer.from_pretrained``.
    """
    # NOTE(review): confirm the default id resolves on the Hub; the published
    # checkpoint may be named differently (e.g. 'meta-llama/Prompt-Guard-86M').
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
    """
    Evaluate the model on the given text with temperature-adjusted softmax.

    Args:
        model: Sequence-classification model. It is called as
            ``model(**inputs)`` and must return an object with a ``logits``
            attribute. The model must already reside on ``device``; only the
            tokenized inputs are moved there by this function.
        tokenizer: Tokenizer used to encode ``text``.
        text (str or list[str]): The input text(s) to classify. Inputs are
            padded and truncated to at most 512 tokens.
        temperature (float): Strictly positive softmax temperature. Values
            above 1 flatten the distribution; values below 1 sharpen it.
            Default is 1.0.
        device (str): Device the encoded inputs are placed on. Default is
            'cpu'.

    Returns:
        torch.Tensor: Class probabilities, computed as a softmax over the last
        (class) dimension of the temperature-scaled logits.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # A zero temperature divides by zero (inf/NaN probabilities), and a
    # negative one silently reverses the ranking of the classes. The
    # ``not ... > 0`` form also rejects NaN.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    # Encode the text and move the resulting tensors to the target device.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = inputs.to(device)
    # Inference only: disable gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Temperature scaling, then softmax over the class dimension.
    scaled_logits = logits / temperature
    probabilities = softmax(scaled_logits, dim=-1)
    return probabilities
def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
    """
    Score how likely ``text`` is to be a malicious jailbreak or prompt injection.

    Intended for screening the direct dialogue between a user and an LLM.

    Args:
        model: Classifier, forwarded to ``get_class_probabilities``.
        tokenizer: Tokenizer, forwarded to ``get_class_probabilities``.
        text (str): The text to score.
        temperature (float): Softmax temperature. Default is 1.0.
        device (str): Device the inputs are evaluated on. Default is 'cpu'.

    Returns:
        float: Probability of class index 2 (the malicious class) for the
        first row of the model's output.
    """
    malicious_class = 2
    class_probs = get_class_probabilities(
        model, tokenizer, text, temperature=temperature, device=device
    )
    first_row = class_probs[0]
    return first_row[malicious_class].item()
def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
    """
    Score how likely ``text`` is to contain embedded instructions of any kind.

    Both benign and malicious embedded instructions count. This makes the
    score suited to screening third-party content (e.g. web search results or
    tool outputs) before it is passed into an LLM.

    Args:
        model: Classifier, forwarded to ``get_class_probabilities``.
        tokenizer: Tokenizer, forwarded to ``get_class_probabilities``.
        text (str): The text to score.
        temperature (float): Softmax temperature. Default is 1.0.
        device (str): Device the inputs are evaluated on. Default is 'cpu'.

    Returns:
        float: Sum of the probabilities of class indices 1 (embedded
        instructions) and 2 (malicious) for the first row of the model's
        output.
    """
    first_row = get_class_probabilities(
        model, tokenizer, text, temperature=temperature, device=device
    )[0]
    injection_prob, malicious_prob = first_row[1], first_row[2]
    return (injection_prob + malicious_prob).item()
|