inference.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import torch
  2. from torch.nn.functional import softmax
  3. from transformers import (
  4. AutoModelForSequenceClassification,
  5. AutoTokenizer,
  6. )
  7. """
  8. Utilities for loading the PromptGuard model and evaluating text for jailbreaks and indirect injections.
  9. Note that the underlying model has a maximum recommended input size of 512 tokens as a DeBERTa model.
  10. The final two functions in this file implement efficient parallel batched evaluation of the model on a list
  11. of input strings of arbirary length, with the final score for each input being the maximum score across all
  12. chunks of the input string.
  13. """
  14. def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
  15. """
  16. Load the PromptGuard model from Hugging Face or a local model.
  17. Args:
  18. model_name (str): The name of the model to load. Default is 'meta-llama/Prompt-Guard-86M'.
  19. Returns:
  20. transformers.PreTrainedModel: The loaded model.
  21. """
  22. model = AutoModelForSequenceClassification.from_pretrained(model_name)
  23. tokenizer = AutoTokenizer.from_pretrained(model_name)
  24. return model, tokenizer
  25. def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
  26. """
  27. Evaluate the model on the given text with temperature-adjusted softmax.
  28. Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
  29. Args:
  30. text (str): The input text to classify.
  31. temperature (float): The temperature for the softmax function. Default is 1.0.
  32. device (str): The device to evaluate the model on.
  33. Returns:
  34. torch.Tensor: The probability of each class adjusted by the temperature.
  35. """
  36. # Encode the text
  37. inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
  38. inputs = inputs.to(device)
  39. # Get logits from the model
  40. with torch.no_grad():
  41. logits = model(**inputs).logits
  42. # Apply temperature scaling
  43. scaled_logits = logits / temperature
  44. # Apply softmax to get probabilities
  45. probabilities = softmax(scaled_logits, dim=-1)
  46. return probabilities
  47. def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
  48. """
  49. Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
  50. Appropriate for filtering dialogue between a user and an LLM.
  51. Args:
  52. text (str): The input text to evaluate.
  53. temperature (float): The temperature for the softmax function. Default is 1.0.
  54. device (str): The device to evaluate the model on.
  55. Returns:
  56. float: The probability of the text containing malicious content.
  57. """
  58. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
  59. return probabilities[0, 2].item()
  60. def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
  61. """
  62. Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
  63. Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
  64. Args:
  65. text (str): The input text to evaluate.
  66. temperature (float): The temperature for the softmax function. Default is 1.0.
  67. device (str): The device to evaluate the model on.
  68. Returns:
  69. float: The combined probability of the text containing malicious or embedded instructions.
  70. """
  71. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
  72. return (probabilities[0, 1] + probabilities[0, 2]).item()
  73. def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
  74. """
  75. Process a batch of texts and return their class probabilities.
  76. Args:
  77. model (transformers.PreTrainedModel): The loaded model.
  78. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  79. texts (list[str]): A list of texts to process.
  80. temperature (float): The temperature for the softmax function.
  81. device (str): The device to evaluate the model on.
  82. Returns:
  83. torch.Tensor: A tensor containing the class probabilities for each text in the batch.
  84. """
  85. inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
  86. inputs = inputs.to(device)
  87. with torch.no_grad():
  88. logits = model(**inputs).logits
  89. scaled_logits = logits / temperature
  90. probabilities = softmax(scaled_logits, dim=-1)
  91. return probabilities
  92. def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
  93. """
  94. Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
  95. Args:
  96. model (transformers.PreTrainedModel): The loaded model.
  97. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  98. texts (list[str]): A list of texts to evaluate.
  99. score_indices (list[int]): Indices of scores to sum for final score calculation.
  100. temperature (float): The temperature for the softmax function.
  101. device (str): The device to evaluate the model on.
  102. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  103. Returns:
  104. list[float]: A list of scores for each text.
  105. """
  106. all_chunks = []
  107. text_indices = []
  108. for index, text in enumerate(texts):
  109. chunks = [text[i:i+512] for i in range(0, len(text), 512)]
  110. all_chunks.extend(chunks)
  111. text_indices.extend([index] * len(chunks))
  112. all_scores = [0] * len(texts)
  113. for i in range(0, len(all_chunks), max_batch_size):
  114. batch_chunks = all_chunks[i:i+max_batch_size]
  115. batch_indices = text_indices[i:i+max_batch_size]
  116. probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
  117. scores = probabilities[:, score_indices].sum(dim=1).tolist()
  118. for idx, score in zip(batch_indices, scores):
  119. all_scores[idx] = max(all_scores[idx], score)
  120. return all_scores
  121. def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
  122. """
  123. Compute jailbreak scores for a list of texts.
  124. Args:
  125. model (transformers.PreTrainedModel): The loaded model.
  126. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  127. texts (list[str]): A list of texts to evaluate.
  128. temperature (float): The temperature for the softmax function.
  129. device (str): The device to evaluate the model on.
  130. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  131. Returns:
  132. list[float]: A list of jailbreak scores for each text.
  133. """
  134. return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
  135. def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
  136. """
  137. Compute indirect injection scores for a list of texts.
  138. Args:
  139. model (transformers.PreTrainedModel): The loaded model.
  140. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  141. texts (list[str]): A list of texts to evaluate.
  142. temperature (float): The temperature for the softmax function.
  143. device (str): The device to evaluate the model on.
  144. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  145. Returns:
  146. list[float]: A list of indirect injection scores for each text.
  147. """
  148. return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)