浏览代码

Prompt Guard Tutorial

Cyrus Nikolaidis 9 月之前
父节点
当前提交
0c57646481

文件差异内容过多而无法显示
+ 817 - 0
recipes/responsible_ai/prompt_guard/Prompt Guard Tutorial.ipynb


+ 11 - 0
recipes/responsible_ai/prompt_guard/README.md

@@ -0,0 +1,11 @@
+# Prompt Guard demo
+<!-- markdown-link-check-disable -->
+Prompt Guard is a classifier model that provides input guardrails for LLM inference, particularly against *prompt attacks. For more details and model cards, please visit the main repository, [Meta Prompt Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Prompt-Guard)
+
+This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path. It also contains a comprehensive demo demonstrating the scenarios in which the model is effective and a script for fine-tuning the model.
+
+This is a very small model and inference and fine-tuning are feasible on local CPUs.
+
+## Requirements
+1. Access to Prompt Guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Prompt-Guard#download)
+2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)

+ 0 - 0
recipes/responsible_ai/prompt_guard/__init__.py


+ 85 - 0
recipes/responsible_ai/prompt_guard/inference.py

@@ -0,0 +1,85 @@
+import torch
+from torch.nn.functional import softmax
+
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+
+"""
+Utilities for loading the PromptGuard model and evaluating text for jailbreaks and indirect injections.
+"""
+
+
+def load_model_and_tokenizer(model_name='meta-llama/PromptGuard'):
+    """
+    Load the PromptGuard model from Hugging Face or a local model.
+    
+    Args:
+        model_name (str): The name of the model to load. Default is 'meta-llama/PromptGuard'.
+        
+    Returns:
+        transformers.PreTrainedModel: The loaded model.
+    """
+    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    return model, tokenizer
+
+
+def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
+    """
+    Evaluate the model on the given text with temperature-adjusted softmax.
+    
+    Args:
+        text (str): The input text to classify.
+        temperature (float): The temperature for the softmax function. Default is 1.0.
+        device (str): The device to evaluate the model on.
+        
+    Returns:
+        torch.Tensor: The probability of each class adjusted by the temperature.
+    """
+    # Encode the text
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    inputs = inputs.to(device)
+    # Get logits from the model
+    with torch.no_grad():
+        logits = model(**inputs).logits
+    # Apply temperature scaling
+    scaled_logits = logits / temperature
+    # Apply softmax to get probabilities
+    probabilities = softmax(scaled_logits, dim=-1)
+    return probabilities
+
+
+def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+    """
+    Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
+    Appropriate for filtering dialogue between a user and an LLM.
+    
+    Args:
+        text (str): The input text to evaluate.
+        temperature (float): The temperature for the softmax function. Default is 1.0.
+        device (str): The device to evaluate the model on.
+        
+    Returns:
+        float: The probability of the text containing malicious content.
+    """
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    return probabilities[0, 2].item()
+
+
+def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+    """
+    Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
+    Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
+    
+    Args:
+        text (str): The input text to evaluate.
+        temperature (float): The temperature for the softmax function. Default is 1.0.
+        device (str): The device to evaluate the model on.
+        
+    Returns:
+        float: The combined probability of the text containing malicious or embedded instructions.
+    """
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    return (probabilities[0, 1] + probabilities[0, 2]).item()