inference.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import torch
  2. from torch.nn.functional import softmax
  3. from transformers import (
  4. AutoModelForSequenceClassification,
  5. AutoTokenizer,
  6. )
  7. """
  8. Utilities for loading the PromptGuard model and evaluating text for jailbreaks and indirect injections.
  9. """
  10. def load_model_and_tokenizer(model_name='meta-llama/PromptGuard'):
  11. """
  12. Load the PromptGuard model from Hugging Face or a local model.
  13. Args:
  14. model_name (str): The name of the model to load. Default is 'meta-llama/PromptGuard'.
  15. Returns:
  16. transformers.PreTrainedModel: The loaded model.
  17. """
  18. model = AutoModelForSequenceClassification.from_pretrained(model_name)
  19. tokenizer = AutoTokenizer.from_pretrained(model_name)
  20. return model, tokenizer
  21. def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
  22. """
  23. Evaluate the model on the given text with temperature-adjusted softmax.
  24. Args:
  25. text (str): The input text to classify.
  26. temperature (float): The temperature for the softmax function. Default is 1.0.
  27. device (str): The device to evaluate the model on.
  28. Returns:
  29. torch.Tensor: The probability of each class adjusted by the temperature.
  30. """
  31. # Encode the text
  32. inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
  33. inputs = inputs.to(device)
  34. # Get logits from the model
  35. with torch.no_grad():
  36. logits = model(**inputs).logits
  37. # Apply temperature scaling
  38. scaled_logits = logits / temperature
  39. # Apply softmax to get probabilities
  40. probabilities = softmax(scaled_logits, dim=-1)
  41. return probabilities
  42. def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
  43. """
  44. Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
  45. Appropriate for filtering dialogue between a user and an LLM.
  46. Args:
  47. text (str): The input text to evaluate.
  48. temperature (float): The temperature for the softmax function. Default is 1.0.
  49. device (str): The device to evaluate the model on.
  50. Returns:
  51. float: The probability of the text containing malicious content.
  52. """
  53. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
  54. return probabilities[0, 2].item()
  55. def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
  56. """
  57. Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
  58. Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
  59. Args:
  60. text (str): The input text to evaluate.
  61. temperature (float): The temperature for the softmax function. Default is 1.0.
  62. device (str): The device to evaluate the model on.
  63. Returns:
  64. float: The combined probability of the text containing malicious or embedded instructions.
  65. """
  66. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
  67. return (probabilities[0, 1] + probabilities[0, 2]).item()