prompt_guard_1_inference.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. import torch
  2. from torch.nn.functional import softmax
  3. from transformers import AutoModelForSequenceClassification, AutoTokenizer
  4. """
  5. Utilities for loading the PromptGuard 1 model and evaluating text for jailbreaks and indirect injections.
  6. NOTE: this code is for PromptGuard 1. For our newer PromptGuard 2 model, see inference.py
  7. Note that the underlying model has a maximum recommended input size of 512 tokens as a DeBERTa model.
  8. The final two functions in this file implement efficient parallel batched evaluation of the model on a list
  9. of input strings of arbitrary length, with the final score for each input being the maximum score across all
  10. chunks of the input string.
  11. """
  12. def load_model_and_tokenizer(model_name="meta-llama/Prompt-Guard-86M"):
  13. """
  14. Load the PromptGuard model from Hugging Face or a local model.
  15. Args:
  16. model_name (str): The name of the model to load. Default is 'meta-llama/Prompt-Guard-86M'.
  17. Returns:
  18. transformers.PreTrainedModel: The loaded model.
  19. """
  20. model = AutoModelForSequenceClassification.from_pretrained(model_name)
  21. tokenizer = AutoTokenizer.from_pretrained(model_name)
  22. return model, tokenizer
  23. def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
  24. """
  25. Preprocess the text by removing spaces that break apart larger tokens.
  26. This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
  27. to allow the string to be classified as benign.
  28. Args:
  29. text (str): The input text to preprocess.
  30. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  31. Returns:
  32. str: The preprocessed text.
  33. """
  34. try:
  35. cleaned_text = ""
  36. index_map = []
  37. for i, char in enumerate(text):
  38. if not char.isspace():
  39. cleaned_text += char
  40. index_map.append(i)
  41. tokens = tokenizer.tokenize(cleaned_text)
  42. result = []
  43. last_end = 0
  44. for token in tokens:
  45. token_str = tokenizer.convert_tokens_to_string([token])
  46. start = cleaned_text.index(token_str, last_end)
  47. end = start + len(token_str)
  48. original_start = index_map[start]
  49. if original_start > 0 and text[original_start - 1].isspace():
  50. result.append(" ")
  51. result.append(token_str)
  52. last_end = end
  53. return "".join(result)
  54. except Exception:
  55. return text
  56. def get_class_probabilities(
  57. model, tokenizer, text, temperature=1.0, device="cpu", preprocess=True
  58. ):
  59. """
  60. Evaluate the model on the given text with temperature-adjusted softmax.
  61. Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
  62. Args:
  63. text (str): The input text to classify.
  64. temperature (float): The temperature for the softmax function. Default is 1.0.
  65. device (str): The device to evaluate the model on.
  66. Returns:
  67. torch.Tensor: The probability of each class adjusted by the temperature.
  68. """
  69. if preprocess:
  70. text = preprocess_text_for_promptguard(text, tokenizer)
  71. # Encode the text
  72. inputs = tokenizer(
  73. text, return_tensors="pt", padding=True, truncation=True, max_length=512
  74. )
  75. inputs = inputs.to(device)
  76. # Get logits from the model
  77. with torch.no_grad():
  78. logits = model(**inputs).logits
  79. # Apply temperature scaling
  80. scaled_logits = logits / temperature
  81. # Apply softmax to get probabilities
  82. probabilities = softmax(scaled_logits, dim=-1)
  83. return probabilities
  84. def get_jailbreak_score(
  85. model, tokenizer, text, temperature=1.0, device="cpu", preprocess=True
  86. ):
  87. """
  88. Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
  89. Appropriate for filtering dialogue between a user and an LLM.
  90. Args:
  91. text (str): The input text to evaluate.
  92. temperature (float): The temperature for the softmax function. Default is 1.0.
  93. device (str): The device to evaluate the model on.
  94. Returns:
  95. float: The probability of the text containing malicious content.
  96. """
  97. probabilities = get_class_probabilities(
  98. model, tokenizer, text, temperature, device, preprocess
  99. )
  100. return probabilities[0, 2].item()
  101. def get_indirect_injection_score(
  102. model, tokenizer, text, temperature=1.0, device="cpu", preprocess=True
  103. ):
  104. """
  105. Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
  106. Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
  107. Args:
  108. text (str): The input text to evaluate.
  109. temperature (float): The temperature for the softmax function. Default is 1.0.
  110. device (str): The device to evaluate the model on.
  111. Returns:
  112. float: The combined probability of the text containing malicious or embedded instructions.
  113. """
  114. probabilities = get_class_probabilities(
  115. model, tokenizer, text, temperature, device, preprocess
  116. )
  117. return (probabilities[0, 1] + probabilities[0, 2]).item()
  118. def process_text_batch(
  119. model, tokenizer, texts, temperature=1.0, device="cpu", preprocess=True
  120. ):
  121. """
  122. Process a batch of texts and return their class probabilities.
  123. Args:
  124. model (transformers.PreTrainedModel): The loaded model.
  125. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  126. texts (list[str]): A list of texts to process.
  127. temperature (float): The temperature for the softmax function.
  128. device (str): The device to evaluate the model on.
  129. Returns:
  130. torch.Tensor: A tensor containing the class probabilities for each text in the batch.
  131. """
  132. if preprocess:
  133. texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
  134. inputs = tokenizer(
  135. texts, return_tensors="pt", padding=True, truncation=True, max_length=512
  136. )
  137. inputs = inputs.to(device)
  138. with torch.no_grad():
  139. logits = model(**inputs).logits
  140. scaled_logits = logits / temperature
  141. probabilities = softmax(scaled_logits, dim=-1)
  142. return probabilities
  143. def get_scores_for_texts(
  144. model,
  145. tokenizer,
  146. texts,
  147. score_indices,
  148. temperature=1.0,
  149. device="cpu",
  150. max_batch_size=16,
  151. preprocess=True,
  152. ):
  153. """
  154. Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
  155. Args:
  156. model (transformers.PreTrainedModel): The loaded model.
  157. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  158. texts (list[str]): A list of texts to evaluate.
  159. score_indices (list[int]): Indices of scores to sum for final score calculation.
  160. temperature (float): The temperature for the softmax function.
  161. device (str): The device to evaluate the model on.
  162. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  163. Returns:
  164. list[float]: A list of scores for each text.
  165. """
  166. all_chunks = []
  167. text_indices = []
  168. for index, text in enumerate(texts):
  169. chunks = [text[i : i + 512] for i in range(0, len(text), 512)]
  170. all_chunks.extend(chunks)
  171. text_indices.extend([index] * len(chunks))
  172. all_scores = [0] * len(texts)
  173. for i in range(0, len(all_chunks), max_batch_size):
  174. batch_chunks = all_chunks[i : i + max_batch_size]
  175. batch_indices = text_indices[i : i + max_batch_size]
  176. probabilities = process_text_batch(
  177. model, tokenizer, batch_chunks, temperature, device, preprocess
  178. )
  179. scores = probabilities[:, score_indices].sum(dim=1).tolist()
  180. for idx, score in zip(batch_indices, scores):
  181. all_scores[idx] = max(all_scores[idx], score)
  182. return all_scores
  183. def get_jailbreak_scores_for_texts(
  184. model,
  185. tokenizer,
  186. texts,
  187. temperature=1.0,
  188. device="cpu",
  189. max_batch_size=16,
  190. preprocess=True,
  191. ):
  192. """
  193. Compute jailbreak scores for a list of texts.
  194. Args:
  195. model (transformers.PreTrainedModel): The loaded model.
  196. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  197. texts (list[str]): A list of texts to evaluate.
  198. temperature (float): The temperature for the softmax function.
  199. device (str): The device to evaluate the model on.
  200. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  201. Returns:
  202. list[float]: A list of jailbreak scores for each text.
  203. """
  204. return get_scores_for_texts(
  205. model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess
  206. )
  207. def get_indirect_injection_scores_for_texts(
  208. model,
  209. tokenizer,
  210. texts,
  211. temperature=1.0,
  212. device="cpu",
  213. max_batch_size=16,
  214. preprocess=True,
  215. ):
  216. """
  217. Compute indirect injection scores for a list of texts.
  218. Args:
  219. model (transformers.PreTrainedModel): The loaded model.
  220. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  221. texts (list[str]): A list of texts to evaluate.
  222. temperature (float): The temperature for the softmax function.
  223. device (str): The device to evaluate the model on.
  224. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  225. Returns:
  226. list[float]: A list of indirect injection scores for each text.
  227. """
  228. return get_scores_for_texts(
  229. model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess
  230. )