inference.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from typing import List, Tuple
  2. import torch
  3. from torch.nn.functional import softmax
  4. from transformers import AutoModelForSequenceClassification, AutoTokenizer
  5. """
  6. Utilities for loading the PromptGuard model and evaluating text for jailbreaking techniques.
  7. NOTE: this code is for PromptGuard 2. For our older PromptGuard 1 model, see prompt_guard_1_inference.py
  8. Note that the underlying model has a maximum recommended input size of 512 tokens as a DeBERTa model.
  9. The final two functions in this file implement efficient parallel batched evaluation of the model on a list
  10. of input strings of arbitrary length, with the final score for each input being the maximum score across all
  11. chunks of the input string.
  12. """
  13. MAX_TOKENS = 512
  14. DEFAULT_BATCH_SIZE = 16
  15. DEFAULT_TEMPERATURE = 1.0
  16. DEFAULT_DEVICE = "cpu"
  17. DEFAULT_MODEL_NAME = "meta-llama/Llama-Prompt-Guard-2-86M"
  18. def load_model_and_tokenizer(
  19. model_name: str = "meta-llama/Llama-Prompt-Guard-2-86M", device: str = DEFAULT_DEVICE
  20. ) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer, str]:
  21. """
  22. Load the PromptGuard model and tokenizer, and move the model to the specified device.
  23. Args:
  24. model_name (str): The name of the model to load.
  25. device (str): The device to load the model on. If None, it will use CUDA if available, else CPU.
  26. Returns:
  27. tuple: The loaded model, tokenizer, and the device used.
  28. """
  29. try:
  30. if device is None:
  31. device = "cuda" if torch.cuda.is_available() else "cpu"
  32. model = AutoModelForSequenceClassification.from_pretrained(model_name)
  33. model = model.to(device)
  34. tokenizer = AutoTokenizer.from_pretrained(model_name)
  35. return model, tokenizer, device
  36. except Exception as e:
  37. raise RuntimeError(f"Failed to load model and tokenizer: {str(e)}")
  38. def get_class_scores(
  39. model: AutoModelForSequenceClassification,
  40. tokenizer: AutoTokenizer,
  41. text: str,
  42. temperature: float = DEFAULT_TEMPERATURE,
  43. ) -> torch.Tensor:
  44. """
  45. Evaluate the model on the given text with temperature-adjusted softmax.
  46. Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
  47. Args:
  48. model: The loaded model.
  49. tokenizer: The loaded tokenizer.
  50. text (str): The input text to classify.
  51. temperature (float): The temperature for the softmax function. Default is 1.0.
  52. Returns:
  53. torch.Tensor: The scores for each class adjusted by the temperature.
  54. """
  55. # Get the device from the model
  56. device = next(model.parameters()).device
  57. # Encode the text
  58. inputs = tokenizer(
  59. text, return_tensors="pt", padding=True, truncation=True, max_length=MAX_TOKENS
  60. )
  61. inputs = {k: v.to(device) for k, v in inputs.items()}
  62. # Get logits from the model
  63. with torch.no_grad():
  64. logits = model(**inputs).logits
  65. # Apply temperature scaling
  66. scaled_logits = logits / temperature
  67. # Apply softmax to get scores
  68. scores = softmax(scaled_logits, dim=-1)
  69. return scores
  70. def get_jailbreak_score(
  71. model: AutoModelForSequenceClassification,
  72. tokenizer: AutoTokenizer,
  73. text: str,
  74. temperature: float = DEFAULT_TEMPERATURE,
  75. ) -> float:
  76. """
  77. Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
  78. Appropriate for filtering dialogue between a user and an LLM.
  79. Args:
  80. model: The loaded model.
  81. tokenizer: The loaded tokenizer.
  82. text (str): The input text to evaluate.
  83. temperature (float): The temperature for the softmax function. Default is 1.0.
  84. Returns:
  85. float: The probability of the text containing malicious content.
  86. """
  87. probabilities = get_class_scores(model, tokenizer, text, temperature)
  88. return probabilities[0, 1].item()
  89. def process_text_batch(
  90. model: AutoModelForSequenceClassification,
  91. tokenizer: AutoTokenizer,
  92. texts: List[str],
  93. temperature: float = DEFAULT_TEMPERATURE,
  94. ) -> torch.Tensor:
  95. """
  96. Process a batch of texts and return their class probabilities.
  97. Args:
  98. model (transformers.PreTrainedModel): The loaded model.
  99. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  100. texts (list[str]): A list of texts to process.
  101. temperature (float): The temperature for the softmax function.
  102. Returns:
  103. torch.Tensor: A tensor containing the class probabilities for each text in the batch.
  104. """
  105. # Get the device from the model
  106. device = next(model.parameters()).device
  107. # encode the texts
  108. inputs = tokenizer(
  109. texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_TOKENS
  110. )
  111. inputs = {k: v.to(device) for k, v in inputs.items()}
  112. with torch.no_grad():
  113. logits = model(**inputs).logits
  114. scaled_logits = logits / temperature
  115. probabilities = softmax(scaled_logits, dim=-1)
  116. return probabilities
  117. def get_scores_for_texts(
  118. model: AutoModelForSequenceClassification,
  119. tokenizer: AutoTokenizer,
  120. texts: List[str],
  121. score_indices: List[int],
  122. temperature: float = DEFAULT_TEMPERATURE,
  123. max_batch_size: int = DEFAULT_BATCH_SIZE,
  124. ) -> List[float]:
  125. """
  126. Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
  127. The final score for each text is the maximum score across all chunks of the text.
  128. Args:
  129. model (transformers.PreTrainedModel): The loaded model.
  130. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  131. texts (list[str]): A list of texts to evaluate.
  132. score_indices (list[int]): Indices of scores to sum for final score calculation.
  133. temperature (float): The temperature for the softmax function.
  134. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  135. Returns:
  136. list[float]: A list of scores for each text.
  137. """
  138. all_chunks = []
  139. text_indices = []
  140. for index, text in enumerate(texts):
  141. # Tokenize the text and split into chunks of MAX_TOKENS
  142. tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
  143. chunks = [tokens[i : i + MAX_TOKENS] for i in range(0, len(tokens), MAX_TOKENS)]
  144. all_chunks.extend(chunks)
  145. text_indices.extend([index] * len(chunks))
  146. all_scores = [0.0] * len(texts)
  147. for i in range(0, len(all_chunks), max_batch_size):
  148. batch_chunks = all_chunks[i : i + max_batch_size]
  149. batch_indices = text_indices[i : i + max_batch_size]
  150. # Decode the token chunks back to text
  151. batch_texts = [
  152. tokenizer.decode(chunk, skip_special_tokens=True) for chunk in batch_chunks
  153. ]
  154. probabilities = process_text_batch(model, tokenizer, batch_texts, temperature)
  155. scores = probabilities[:, score_indices].sum(dim=1).tolist()
  156. for idx, score in zip(batch_indices, scores):
  157. all_scores[idx] = max(all_scores[idx], score)
  158. return all_scores
  159. def get_jailbreak_scores_for_texts(
  160. model: AutoModelForSequenceClassification,
  161. tokenizer: AutoTokenizer,
  162. texts: List[str],
  163. temperature: float = DEFAULT_TEMPERATURE,
  164. max_batch_size: int = DEFAULT_BATCH_SIZE,
  165. ) -> List[float]:
  166. """
  167. Compute jailbreak scores for a list of texts.
  168. Args:
  169. model (transformers.PreTrainedModel): The loaded model.
  170. tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
  171. texts (list[str]): A list of texts to evaluate.
  172. temperature (float): The temperature for the softmax function.
  173. max_batch_size (int): The maximum number of text chunks to process in a single batch.
  174. Returns:
  175. list[float]: A list of jailbreak scores for each text.
  176. """
  177. return get_scores_for_texts(
  178. model, tokenizer, texts, [1], temperature, max_batch_size
  179. )