|
@@ -201,7 +201,31 @@ class LlamaGuardSafetyChecker(object):
|
|
|
report = result
|
|
|
|
|
|
return "Llama Guard", is_safe, report
|
|
|
+
|
|
|
class PromptGuardSafetyChecker(object):
    """Safety checker backed by the Prompt Guard text classifier.

    Prompt Guard is a sequence-classification model, not a generative one:
    it emits one logit per label instead of generating a "safe"/"unsafe"
    string. The checker therefore runs a single forward pass and reads the
    predicted label from the model configuration.
    """

    def __init__(self, model_id="llhf/Prompt-Guard-86M"):
        """Load the Prompt Guard tokenizer and classification model.

        Args:
            model_id: Hugging Face repository id or local path of the
                Prompt Guard checkpoint.
                NOTE(review): the public release appears to be published as
                "meta-llama/Prompt-Guard-86M" -- confirm which id this
                deployment should use.
        """
        # Imported lazily so transformers is only required when this checker
        # is actually enabled.
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)

    def __call__(self, text_for_check, **kwargs):
        """Classify ``text_for_check`` and report whether it looks safe.

        Args:
            text_for_check: The text to classify.
            **kwargs: Accepted for interface compatibility with the other
                safety checkers; not used here.

        Returns:
            A tuple ``("PromptGuard", is_safe, report)`` where ``is_safe`` is
            a bool and ``report`` is a string holding the predicted label
            followed by the per-label probabilities.
        """
        import torch

        # Anything past max_length tokens is truncated and so NOT inspected.
        # NOTE(review): 512 is taken from the model card's stated context
        # window -- confirm, and consider chunking long inputs.
        encoded = self.tokenizer(
            text_for_check,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        # Follow whatever device the model lives on instead of assuming CUDA.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probabilities = torch.softmax(logits, dim=-1)[0]
        id2label = self.model.config.id2label
        label = id2label[int(probabilities.argmax())]

        # NOTE(review): every non-BENIGN label is treated as unsafe. For
        # direct user dialogue it may be preferable to flag only JAILBREAK --
        # confirm against the model card guidance.
        is_safe = label.upper() == "BENIGN"

        scores = {
            id2label[index]: round(probability, 4)
            for index, probability in enumerate(probabilities.tolist())
        }
        report = f"{label}: {scores}"

        return "PromptGuard", is_safe, report
|
|
|
+
|
|
|
|
|
|
# Function to load the PeftModel for performance optimization
|
|
|
# Function to determine which safety checker to use based on the options selected
|