Browse Source

Add PromptGuard to safety checker. Update inference scripts for new change

Thomas Robinson 11 months ago
parent
commit
f089186ff4

+ 2 - 0
recipes/quickstart/inference/code_llama/code_completion_example.py

@@ -34,6 +34,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
+    enable_promptguard_safety: bool = False,
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -62,6 +63,7 @@ def main(
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
                                         enable_llamaguard_content_safety,
+                                        enable_promptguard_safety,
                                         )
 
     # Safety check of the user prompt

+ 2 - 0
recipes/quickstart/inference/code_llama/code_infilling_example.py

@@ -33,6 +33,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
+    enable_promptguard_safety: bool = False,
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -62,6 +63,7 @@ def main(
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
                                         enable_llamaguard_content_safety,
+                                        enable_promptguard_safety
                                         )
 
     # Safety check of the user prompt

+ 2 - 0
recipes/quickstart/inference/code_llama/code_instruct_example.py

@@ -71,6 +71,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
+    enable_promptguard_safety: bool = False,
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -95,6 +96,7 @@ def main(
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
                                         enable_llamaguard_content_safety,
+                                        enable_promptguard_safety
                                         )
 
     # Safety check of the user prompt

+ 2 - 0
recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py

@@ -37,6 +37,7 @@ def main(
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     enable_llamaguard_content_safety: bool = False,
+    enable_promptguard_safety: bool = False,
     **kwargs
 ):
     if prompt_file is not None:
@@ -81,6 +82,7 @@ def main(
                                         enable_sensitive_topics,
                                         enable_saleforce_content_safety,
                                         enable_llamaguard_content_safety,
+                                        enable_promptguard_safety
                                         )
             # Safety check of the user prompt
             safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]

+ 5 - 2
recipes/quickstart/inference/local_inference/inference.py

@@ -36,6 +36,7 @@ def main(
     enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool = True,  # Enable safety check with Salesforce safety flan t5
     enable_llamaguard_content_safety: bool = False,
+    enable_promptguard_safety: bool = False,
     max_padding_length: int = None,  # the max padding length to be used with tokenizer padding the prompts.
     use_fast_kernels: bool = False,  # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     share_gradio: bool = False,  # Enable endpoint creation for gradio.live
@@ -70,8 +71,10 @@ def main(
             enable_sensitive_topics,
             enable_salesforce_content_safety,
             enable_llamaguard_content_safety,
+            enable_promptguard_safety,
         )
 
+
         # Safety check of the user prompt
         safety_results = [check(user_prompt) for check in safety_checker]
         are_safe = all([r[1] for r in safety_results])
@@ -119,9 +122,9 @@ def main(
 
         # Safety check of the model output
         safety_results = [
-            check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt)
-            for check in safety_checker
+            check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker
         ]
+        print(safety_results)
         are_safe = all([r[1] for r in safety_results])
         if are_safe:
             print("User input and model output deemed safe.")

+ 35 - 16
src/llama_recipes/inference/safety_utils.py

@@ -205,34 +205,51 @@ class LlamaGuardSafetyChecker(object):
 class PromptGuardSafetyChecker(object):
 
     def __init__(self):
-        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-        model_id = "llhf/Prompt-Guard-86M"
+        from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
+        model_id = "meta-llama/Prompt-Guard-86M"
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
 
-    def __call__(self, text_for_check, **kwargs):
+    def get_scores(self, text, temperature=1.0, device='cpu'):
+        from torch.nn.functional import softmax
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+        inputs = inputs.to(device)
+        if len(inputs[0]) > 512:
+            warnings.warn(
+                "Input length is > 512 token. PromptGuard check result could be incorrect."
+            )
+        with torch.no_grad():
+            logits = self.model(**inputs).logits
+        scaled_logits = logits / temperature
+        probabilities = softmax(scaled_logits, dim=-1)
 
-        #tbd
-        input_ids = self.tokenizer.apply_chat_template(text_for_check, return_tensors="pt").to("cuda")
-        prompt_len = input_ids.shape[-1]
-        output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
-        result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
-        
-        splitted_result = result.split("\n")[0];
-        is_safe = splitted_result == "safe"    
+        return {
+            'jailbreak': probabilities[0, 2].item(),
+            'indirect_injection': (probabilities[0, 1] + probabilities[0, 2]).item()
+        }
 
-        report = result
-        
+    def __call__(self, text_for_check, **kwargs):     
+        agent_type = kwargs.get('agent_type', AgentType.USER)
+        if agent_type == AgentType.AGENT:
+            return "PromptGuard", True, "PromptGuard is not used for model output so checking not carried out"
+        sentences = text_for_check.split(".")
+        running_scores = {'jailbreak':0, 'indirect_injection' :0}
+        for sentence in sentences:
+            scores = self.get_scores(sentence)
+            running_scores['jailbreak'] = max([running_scores['jailbreak'],scores['jailbreak']])
+            running_scores['indirect_injection'] = max([running_scores['indirect_injection'],scores['indirect_injection']])
+        is_safe = True if running_scores['jailbreak'] < 0.5 else False
+        report = str(running_scores)
         return "PromptGuard", is_safe, report
 
 
-# Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
                        enable_salesforce_content_safety,
-                       enable_llamaguard_content_safety):
+                       enable_llamaguard_content_safety,
+                       enable_promptguard_safety):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
@@ -242,5 +259,7 @@ def get_safety_checker(enable_azure_content_safety,
         safety_checker.append(SalesforceSafetyChecker())
     if enable_llamaguard_content_safety:
         safety_checker.append(LlamaGuardSafetyChecker())
+    if enable_promptguard_safety:
+        safety_checker.append(PromptGuardSafetyChecker())
     return safety_checker