
Adds tentative llamaguard HF model id, eos_token_id for model.generate

Suraj, 9 months ago
commit 308026aad5

+ 2 - 2
recipes/responsible_ai/README.md

@@ -1,8 +1,8 @@
 # Meta Llama Guard
 
-Meta Llama Guard and Meta Llama Guard 2 are new models that provide input and output guardrails for LLM inference. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard2).
+Meta Llama Guard models provide input and output guardrails for LLM inference. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/).
 
-**Note** Please find the right model on HF side [here](https://huggingface.co/meta-llama/Meta-Llama-Guard-2-8B).
+**Note:** Please find the right model on the Hugging Face Hub [here](https://huggingface.co/meta-llama/Llama-Guard-3-8B).
 
 ### Running locally
 The [llama_guard](llama_guard) folder contains the inference script to run Meta Llama Guard locally. Add test prompts directly to the [inference script](llama_guard/inference.py) before running it.
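
For reference, adding a test prompt typically means editing a prompts list inside the inference script. A minimal sketch, assuming the script pairs each conversation with an agent type as below; the actual variable name and structure in inference.py may differ:

```python
from enum import Enum

class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"

# Hypothetical test prompts: each entry pairs a conversation (a list of
# alternating user/agent turns) with the role whose message is checked.
prompts = [
    # Check a lone user prompt
    (["How do I bake a lemon cake?"], AgentType.USER),
    # Check the agent's response to a user prompt
    (["What is the capital of France?", "The capital of France is Paris."], AgentType.AGENT),
]
```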

+ 3 - 3
recipes/responsible_ai/llama_guard/README.md

@@ -1,6 +1,6 @@
 # Meta Llama Guard demo
 <!-- markdown-link-check-disable -->
-Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2).
+Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the [PurpleLlama](https://github.com/meta-llama/PurpleLlama) repository.
 
 This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
 
@@ -55,9 +55,9 @@ This is the output:
 
 To run it with a local model, you can use the `model_id` param in the inference script:
 
-`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/llama_guard_2-hf/ --llama_guard_version=LLAMA_GUARD_2`
+`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/Llama-Guard-3-8B/ --llama_guard_version=LLAMA_GUARD_3`
 
-Note: Make sure to also add the llama_guard_version if when it does not match the default, the script allows you to run the prompt format from Meta Llama Guard 1 on Meta Llama Guard 2
+Note: Make sure to also set the `llama_guard_version` parameter when it does not match the default, which is `LLAMA_GUARD_3`.
 
 ## Inference Safety Checker
 When running the regular inference script with prompts, Meta Llama Guard is used as a safety checker on both the user prompt and the model output. If both are safe, the result is shown; otherwise an error message is shown, containing the word unsafe and a comma-separated list of the violated categories. Meta Llama Guard is always loaded quantized with bitsandbytes via the Hugging Face Transformers library.
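
For context, "loaded quantized with bitsandbytes" usually corresponds to passing a quantization config when loading the model with Transformers. A minimal sketch, assuming 8-bit quantization; the recipe's actual loading code may use different settings:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-Guard-3-8B"

# 8-bit quantization via bitsandbytes (assumed setting; the script may differ)
quant_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto",  # place layers on available GPUs automatically
)
```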

+ 3 - 3
recipes/responsible_ai/llama_guard/inference.py

@@ -14,8 +14,8 @@ class AgentType(Enum):
     USER = "User"
 
 def main(
-    model_id: str = "meta-llama/LlamaGuard-7b",
-    llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_1
+    model_id: str = "meta-llama/Llama-Guard-3-8B",
+    llama_guard_version: str = "LLAMA_GUARD_3"
 ):
     """
     Entry point for Llama Guard inference sample script.
@@ -60,7 +60,7 @@ def main(
 
         input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
         prompt_len = input["input_ids"].shape[-1]
-        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
+        output = model.generate(**input, max_new_tokens=100, pad_token_id=0, eos_token_id=128009)
         results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
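
The hardcoded 128009 is the id of Llama 3's `<|eot_id|>` end-of-turn token, so generation now stops at the end of the model's turn rather than running to `max_new_tokens`. A more robust sketch would look the id up from the tokenizer instead of hardcoding it; this is an illustrative alternative, not what the commit does:

```python
# Resolve the end-of-turn token id from the tokenizer; for the Llama 3
# family, <|eot_id|> maps to 128009.
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

output = model.generate(
    **input,
    max_new_tokens=100,
    pad_token_id=0,
    eos_token_id=eot_id,
)
```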