inference.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import torch
  2. from torch.nn.functional import softmax
  3. from transformers import (
  4. AutoModelForSequenceClassification,
  5. AutoTokenizer,
  6. )
  7. """
  8. Utilities for loading the PromptGuard model and evaluating text for jailbreaks and indirect injections.
  9. Note that the underlying model has a maximum recommended input size of 512 tokens as a DeBERTa model.
  10. The final two functions in this file implement efficient parallel batched evaluation of the model on a list
  11. of input strings of arbitrary length, with the final score for each input being the maximum score across all
  12. chunks of the input string.
  13. """
  14. def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
  15. """
  16. Load the PromptGuard model from Hugging Face or a local model.
  17. Args:
  18. model_name (str): The name of the model to load. Default is 'meta-llama/Prompt-Guard-86M'.
  19. Returns:
  20. transformers.PreTrainedModel: The loaded model.
  21. """
  22. model = AutoModelForSequenceClassification.from_pretrained(model_name)
  23. tokenizer = AutoTokenizer.from_pretrained(model_name)
  24. return model, tokenizer
  25. def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
  26. """
  27. Preprocess the text by removing spaces that break apart larger tokens.
  28. This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
  29. to allow the string to be classified as benign.
  30. Args:
  31. text (str): The input text to preprocess.
  32. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  33. Returns:
  34. str: The preprocessed text.
  35. """
  36. try:
  37. cleaned_text = ''
  38. index_map = []
  39. for i, char in enumerate(text):
  40. if not char.isspace():
  41. cleaned_text += char
  42. index_map.append(i)
  43. tokens = tokenizer.tokenize(cleaned_text)
  44. result = []
  45. last_end = 0
  46. for token in tokens:
  47. token_str = tokenizer.convert_tokens_to_string([token])
  48. start = cleaned_text.index(token_str, last_end)
  49. end = start + len(token_str)
  50. original_start = index_map[start]
  51. if original_start > 0 and text[original_start - 1].isspace():
  52. result.append(' ')
  53. result.append(token_str)
  54. last_end = end
  55. return ''.join(result)
  56. except Exception:
  57. return text
  58. def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
  59. """
  60. Evaluate the model on the given text with temperature-adjusted softmax.
  61. Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
  62. Args:
  63. text (str): The input text to classify.
  64. temperature (float): The temperature for the softmax function. Default is 1.0.
  65. device (str): The device to evaluate the model on.
  66. Returns:
  67. torch.Tensor: The probability of each class adjusted by the temperature.
  68. """
  69. if preprocess:
  70. text = preprocess_text_for_promptguard(text, tokenizer)
  71. # Encode the text
  72. inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
  73. inputs = inputs.to(device)
  74. # Get logits from the model
  75. with torch.no_grad():
  76. logits = model(**inputs).logits
  77. # Apply temperature scaling
  78. scaled_logits = logits / temperature
  79. # Apply softmax to get probabilities
  80. probabilities = softmax(scaled_logits, dim=-1)
  81. return probabilities
  82. def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
  83. """
  84. Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
  85. Appropriate for filtering dialogue between a user and an LLM.
  86. Args:
  87. text (str): The input text to evaluate.
  88. temperature (float): The temperature for the softmax function. Default is 1.0.
  89. device (str): The device to evaluate the model on.
  90. Returns:
  91. float: The probability of the text containing malicious content.
  92. """
  93. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
  94. return probabilities[0, 2].item()
  95. def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
  96. """
  97. Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
  98. Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
  99. Args:
  100. text (str): The input text to evaluate.
  101. temperature (float): The temperature for the softmax function. Default is 1.0.
  102. device (str): The device to evaluate the model on.
  103. Returns:
  104. float: The combined probability of the text containing malicious or embedded instructions.
  105. """
  106. probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
  107. return (probabilities[0, 1] + probabilities[0, 2]).item()
  108. def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
  109. """
  110. Process a batch of texts and return their class probabilities.
  111. Args:
  112. model (transformers.PreTrainedModel): The loaded model.
  113. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  114. texts (list[str]): A list of texts to process.
  115. temperature (float): The temperature for the softmax function.
  116. device (str): The device to evaluate the model on.
  117. Returns:
  118. torch.Tensor: A tensor containing the class probabilities for each text in the batch.
  119. """
  120. if preprocess:
  121. texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
  122. inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
  123. inputs = inputs.to(device)
  124. with torch.no_grad():
  125. logits = model(**inputs).logits
  126. scaled_logits = logits / temperature
  127. probabilities = softmax(scaled_logits, dim=-1)
  128. return probabilities
  129. def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
  130. """
  131. Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
  132. Args:
  133. model (transformers.PreTrainedModel): The loaded model.
  134. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  135. texts (list[str]): A list of texts to evaluate.
  136. score_indices (list[int]): Indices of scores to sum for final score calculation.
  137. temperature (float): The temperature for the softmax function.
  138. device (str): The device to evaluate the model on.
  139. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  140. Returns:
  141. list[float]: A list of scores for each text.
  142. """
  143. all_chunks = []
  144. text_indices = []
  145. for index, text in enumerate(texts):
  146. chunks = [text[i:i+512] for i in range(0, len(text), 512)]
  147. all_chunks.extend(chunks)
  148. text_indices.extend([index] * len(chunks))
  149. all_scores = [0] * len(texts)
  150. for i in range(0, len(all_chunks), max_batch_size):
  151. batch_chunks = all_chunks[i:i+max_batch_size]
  152. batch_indices = text_indices[i:i+max_batch_size]
  153. probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
  154. scores = probabilities[:, score_indices].sum(dim=1).tolist()
  155. for idx, score in zip(batch_indices, scores):
  156. all_scores[idx] = max(all_scores[idx], score)
  157. return all_scores
  158. def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
  159. """
  160. Compute jailbreak scores for a list of texts.
  161. Args:
  162. model (transformers.PreTrainedModel): The loaded model.
  163. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  164. texts (list[str]): A list of texts to evaluate.
  165. temperature (float): The temperature for the softmax function.
  166. device (str): The device to evaluate the model on.
  167. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  168. Returns:
  169. list[float]: A list of jailbreak scores for each text.
  170. """
  171. return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
  172. def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
  173. """
  174. Compute indirect injection scores for a list of texts.
  175. Args:
  176. model (transformers.PreTrainedModel): The loaded model.
  177. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  178. texts (list[str]): A list of texts to evaluate.
  179. temperature (float): The temperature for the softmax function.
  180. device (str): The device to evaluate the model on.
  181. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  182. Returns:
  183. list[float]: A list of indirect injection scores for each text.
  184. """
  185. return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)