@@ -24,7 +24,7 @@ class AuditNLGSensitiveTopics(object):
         except ImportError as e:
             print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt")
             raise e
-
+
         data = [{"output": output_text}]
@@ -37,8 +37,8 @@ class AuditNLGSensitiveTopics(object):
             report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n"
             report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n"
         return "Sensitive Topics", is_safe, report
-
-
+
+
 class SalesforceSafetyChecker(object):
     def __init__(self):
         pass
@@ -47,13 +47,13 @@ class SalesforceSafetyChecker(object):
         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig

         config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
-        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
+        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
         model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config)
         model.eval()

         prefix = "Is the <Text> field safe or unsafe?"
         input_ids = tokenizer(prefix + " <Text> " + output_text + " <Context> ", return_tensors="pt").input_ids
-
+
         if len(input_ids[0]) > 512:
             warnings.warn(
                 "Input length is > 512 token. Safety check result could be incorrect."
@@ -66,9 +66,9 @@ class SalesforceSafetyChecker(object):
                 return_dict_in_generate=True,
                 max_new_tokens=20,
             )
-
-        is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"
-
+
+        is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"
+
         report = ""
         if not is_safe:
             true_false_ids = tokenizer("true false").input_ids[:2]
@@ -76,11 +76,11 @@ class SalesforceSafetyChecker(object):
             scores = {}
             for k, i in zip(keys, range(3,20,2)):
                 scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5)
-
+
             report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
             report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
         return "Salesforce Content Safety Flan T5 Base", is_safe, report
-
+

     def get_total_length(self, data):
         prefix = "Is the <Text> field safe or unsafe "
@@ -158,7 +158,7 @@ class LlamaGuardSafetyChecker(object):

     def __init__(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-        from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
+        from llama_cookbook.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion

         model_id = "meta-llama/Llama-Guard-3-8B"
@@ -168,7 +168,7 @@ class LlamaGuardSafetyChecker(object):
         self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")

     def __call__(self, output_text, **kwargs):
-
+
         agent_type = kwargs.get('agent_type', AgentType.USER)
         user_prompt = kwargs.get('user_prompt', "")
@@ -194,14 +194,14 @@ class LlamaGuardSafetyChecker(object):
         prompt_len = input_ids.shape[-1]
         output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
         result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
-
+
         splitted_result = result.split("\n")[0];
-        is_safe = splitted_result == "safe"
+        is_safe = splitted_result == "safe"

         report = result
-
+
         return "Llama Guard", is_safe, report
-
+

 # Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected