Explorar o código

Initial commit

Thomas Robinson hai 11 meses
pai
achega
896757c736
Modificáronse 1 ficheiros con 24 adicións e 0 borrados
  1. 24 0
      src/llama_recipes/inference/safety_utils.py

+ 24 - 0
src/llama_recipes/inference/safety_utils.py

@@ -201,7 +201,31 @@ class LlamaGuardSafetyChecker(object):
         report = result
         
         return "Llama Guard", is_safe, report
+
+class PromptGuardSafetyChecker(object):
+
+    def __init__(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+        model_id = "llhf/Prompt-Guard-86M"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+    def __call__(self, text_for_check, **kwargs):
+
+        #tbd
+        input_ids = self.tokenizer.apply_chat_template(text_for_check, return_tensors="pt").to("cuda")
+        prompt_len = input_ids.shape[-1]
+        output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
+        result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+        
+        splitted_result = result.split("\n")[0];
+        is_safe = splitted_result == "safe"    
+
+        report = result
         
+        return "PromptGuard", is_safe, report
+
 
 # Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected