|
@@ -205,34 +205,51 @@ class LlamaGuardSafetyChecker(object):
|
|
|
class PromptGuardSafetyChecker(object):
|
|
|
|
|
|
def __init__(self):
|
|
|
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
|
- model_id = "llhf/Prompt-Guard-86M"
|
|
|
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
|
|
|
+ model_id = "meta-llama/Prompt-Guard-86M"
|
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
- self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
|
|
|
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
|
|
|
|
|
- def __call__(self, text_for_check, **kwargs):
|
|
|
+ def get_scores(self, text, temperature=1.0, device='cpu'):
|
|
|
+ from torch.nn.functional import softmax
|
|
|
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
|
|
+ inputs = inputs.to(device)
|
|
|
+ if len(inputs[0]) > 512:
|
|
|
+ warnings.warn(
|
|
|
+ "Input length is > 512 token. PromptGuard check result could be incorrect."
|
|
|
+ )
|
|
|
+ with torch.no_grad():
|
|
|
+ logits = self.model(**inputs).logits
|
|
|
+ scaled_logits = logits / temperature
|
|
|
+ probabilities = softmax(scaled_logits, dim=-1)
|
|
|
|
|
|
- #tbd
|
|
|
- input_ids = self.tokenizer.apply_chat_template(text_for_check, return_tensors="pt").to("cuda")
|
|
|
- prompt_len = input_ids.shape[-1]
|
|
|
- output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
|
|
|
- result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
|
|
|
-
|
|
|
- splitted_result = result.split("\n")[0];
|
|
|
- is_safe = splitted_result == "safe"
|
|
|
+ return {
|
|
|
+ 'jailbreak': probabilities[0, 2].item(),
|
|
|
+ 'indirect_injection': (probabilities[0, 1] + probabilities[0, 2]).item()
|
|
|
+ }
|
|
|
|
|
|
- report = result
|
|
|
-
|
|
|
+ def __call__(self, text_for_check, **kwargs):
|
|
|
+ agent_type = kwargs.get('agent_type', AgentType.USER)
|
|
|
+ if agent_type == AgentType.AGENT:
|
|
|
+ return "PromptGuard", True, "PromptGuard is not used for model output so checking not carried out"
|
|
|
+ sentences = text_for_check.split(".")
|
|
|
+ running_scores = {'jailbreak':0, 'indirect_injection' :0}
|
|
|
+ for sentence in sentences:
|
|
|
+ scores = self.get_scores(sentence)
|
|
|
+ running_scores['jailbreak'] = max([running_scores['jailbreak'],scores['jailbreak']])
|
|
|
+ running_scores['indirect_injection'] = max([running_scores['indirect_injection'],scores['indirect_injection']])
|
|
|
+ is_safe = True if running_scores['jailbreak'] < 0.5 else False
|
|
|
+ report = str(running_scores)
|
|
|
return "PromptGuard", is_safe, report
|
|
|
|
|
|
|
|
|
-# Function to load the PeftModel for performance optimization
|
|
|
# Function to determine which safety checker to use based on the options selected
|
|
|
def get_safety_checker(enable_azure_content_safety,
|
|
|
enable_sensitive_topics,
|
|
|
enable_salesforce_content_safety,
|
|
|
- enable_llamaguard_content_safety):
|
|
|
+ enable_llamaguard_content_safety,
|
|
|
+ enable_promptguard_safety):
|
|
|
safety_checker = []
|
|
|
if enable_azure_content_safety:
|
|
|
safety_checker.append(AzureSaftyChecker())
|
|
@@ -242,5 +259,7 @@ def get_safety_checker(enable_azure_content_safety,
|
|
|
safety_checker.append(SalesforceSafetyChecker())
|
|
|
if enable_llamaguard_content_safety:
|
|
|
safety_checker.append(LlamaGuardSafetyChecker())
|
|
|
+ if enable_promptguard_safety:
|
|
|
+ safety_checker.append(PromptGuardSafetyChecker())
|
|
|
return safety_checker
|
|
|
|