
Introduce Llama Guard customization notebook and associated dataset loader example

This notebook explores the customization of Llama Guard 3 for specific application needs. Llama Guard, a versatile AI safety tool, can be adapted to maximise its relevance across a range of scenarios.

We start with zero-shot prompting, a method that lets Llama Guard make predictions without prior explicit examples; it is particularly useful for initial exploration and quick setups. As we progress, we delve into adding and removing safety categories before touching on evaluation and fine-tuning, where we adjust Llama Guard's parameters to better align with our specific data and use cases. By following the steps in this notebook, you should gain a solid understanding of how to tailor Llama Guard 3 effectively, ensuring it performs well for your requirements. The notebook does not cover every aspect of Llama Guard 3, but focuses on several key areas of customization.
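As a taste of the zero-shot step, here is a minimal sketch of prompting Llama Guard 3 through Hugging Face transformers. It assumes access to the gated meta-llama/Llama-Guard-3-8B checkpoint and follows the usage pattern from that model card; the notebook covers this workflow in depth.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "meta-llama/Llama-Guard-3-8B"  # gated; request access on Hugging Face

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

def moderate(chat):
    # The chat template renders the conversation into the Llama Guard prompt
    # format; the model replies "safe", or "unsafe" plus the violated categories.
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)
    output = model.generate(input_ids=input_ids, max_new_tokens=32, pad_token_id=0)
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

print(moderate([{"role": "user", "content": "How do I bake sourdough bread?"}]))  # expected: safe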
Thomas Robinson 9 months ago
parent
commit
1a183c0a5e

File diff not shown because the file is too large
+ 801 - 0
recipes/responsible_ai/llama_guard/llama_guard_customization_via_prompting_changes_and_fine_tuning.ipynb


+ 7 - 1
src/llama_recipes/configs/datasets.py

@@ -25,10 +25,16 @@ class alpaca_dataset:
     test_split: str = "val"
     data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
     
-    
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
     file: str = "recipes/quickstart/finetuning/datasets/custom_dataset.py"
     train_split: str = "train"
     test_split: str = "validation"
+
+@dataclass
+class llamaguard_toxicchat_dataset:
+    dataset: str = "llamaguard_toxicchat_dataset"
+    train_split: str = "train"
+    test_split: str = "test"

+ 2 - 1
src/llama_recipes/datasets/__init__.py

@@ -3,4 +3,5 @@
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
+from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
+from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset

+ 131 - 0
src/llama_recipes/datasets/toxicchat_dataset.py

@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# For dataset details visit: https://huggingface.co/datasets/lmsys/toxic-chat
+
+import copy
+import datasets
+import itertools
+from llama_recipes.inference.prompt_format_utils import LLAMA_GUARD_3_CATEGORY
+import ast
+import fire
+
+def tokenize_prompt_and_labels(full_prompt, tokenizer):
+    # Labels mirror input_ids: the model is trained to reproduce the full
+    # formatted prompt, and every token is attended to.
+    prompt_tokens = tokenizer.encode(full_prompt)
+    combined_tokens = {
+        "input_ids": list(prompt_tokens),
+        "labels": list(prompt_tokens)
+    }
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
+
+
+from llama_recipes.data.llama_guard.finetuning_data_formatter import (
+    TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs,
+    LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs,
+    FormatterConfigs, create_formatted_finetuning_examples,
+)
+from datasets import Dataset, DatasetInfo
+
+def mapTcCategoriesToLGCategories(TcCategoriesString):
+    # ToxicChat stores its OpenAI-moderation annotations as a stringified
+    # list of (category, score) pairs; keep the highest-scored category and
+    # map it to the closest Llama Guard 3 category code.
+    TcCategories = ast.literal_eval(TcCategoriesString)
+    if len(TcCategories) == 0:
+        return None
+    ranked = sorted(TcCategories, key=lambda x: x[1], reverse=True)
+    primary = ranked[0][0]
+    TcMapping = {
+        "sexual": "012",
+        "violence": "01",
+        "sexual/minors": "04",
+        "self-harm/intent": "011",
+        "hate": "010",
+        "harassment": "010",
+        "self-harm": "011",
+        "self-harm/instructions": "011",
+        "violence/graphic": "01",
+        "harassment/threatening": "010",
+        "hate/threatening": "010"
+    }
+    return TcMapping[primary]
+
+def get_llamaguard_toxicchat_dataset(dataset_config, tokenizer, split, return_jsonl=False):
+    dataset = datasets.load_dataset("lmsys/toxic-chat", 'toxicchat0124', split=split)
+
+    # Optionally return plain records (prompt/generation/label) for inspection
+    # or evaluation instead of a tokenized training dataset.
+    if return_jsonl:
+        jsonl_data = []
+        for x in dataset.to_iterable_dataset():
+            jsonl_data.append({
+                "prompt": x['user_input'],
+                "generation": x['model_output'],
+                "label": ("good" if x['toxicity'] == 0 else "bad"),
+                "unsafe_content": [mapTcCategoriesToLGCategories(x['openai_moderation'])]
+            })
+        return jsonl_data
+
+    guidelines = Guidelines(
+        categories=LLAMA_GUARD_3_CATEGORY,
+        category_code_prefix="O")
+
+    # Setting up configs
+    llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+        instructions_format_string="""<|begin_of_text|><|start_header_id|>user<|end_header_id|> Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+    <BEGIN UNSAFE CONTENT CATEGORIES>
+    {guidelines}
+    <END UNSAFE CONTENT CATEGORIES>
+
+    <BEGIN CONVERSATION>
+
+    {conversation}
+
+    <END CONVERSATION>
+
+    Provide your safety assessment for $agent_type in the above conversation:
+    - First line must read 'safe' or 'unsafe'.
+    - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        should_include_category_descriptions=False,
+        should_shuffle_category_codes=False
+    )
+
+    llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+        should_list_violated_codes=True,
+        explanation_position=None
+    )
+
+    augmentation_configs = AugmentationConfigs(
+        should_add_examples_with_dropped_nonviolated_prompt_categories=False,
+        should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
+        explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+    )
+
+    formatter_configs = FormatterConfigs(
+        guidelines=guidelines,
+        llama_guard_prompt_configs=llama_guard_prompt_configs,
+        llama_guard_generation_configs=llama_guard_generation_configs,
+        augmentation_configs=augmentation_configs,
+        random_seed=42
+    )
+
+    # Convert each ToxicChat record into a formatted Llama Guard fine-tuning
+    # example, then tokenize the resulting prompt for training.
+    dataset = dataset.map(lambda x: {"full_prompt": create_formatted_finetuning_examples(
+        [TrainingExample(
+            prompt=x["user_input"],
+            response=None,
+            violated_category_codes=[] if x["toxicity"] == 0 else [mapTcCategoriesToLGCategories(x["openai_moderation"])],
+            label="safe" if x["toxicity"] == 0 else "unsafe",
+            explanation="The response contains violating information."
+        )],
+        formatter_configs)[0]},
+        remove_columns=list(dataset.features))
+
+    dataset = dataset.map(lambda x: tokenize_prompt_and_labels(x["full_prompt"], tokenizer), remove_columns=list(dataset.features))
+    return dataset
+
+def main(return_jsonl=False):
+    from transformers import AutoTokenizer
+    # Local path to interim Llama Guard 3 weights; point this at your own
+    # checkpoint or a Hugging Face model id.
+    model_id: str = "/home/ubuntu/LG3-interim-hf-weights"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    if return_jsonl:
+        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train", return_jsonl=True)
+        print(dataset[0:50])
+    else:
+        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train")
+        print(dataset[0])
+
+if __name__ == '__main__':
+    fire.Fire(main)
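
As a quick, hypothetical sanity check of the mapping helper above (the annotation strings here are invented for illustration; real records carry many more pairs):

from llama_recipes.datasets.toxicchat_dataset import mapTcCategoriesToLGCategories

# The highest-scored ToxicChat category wins: "harassment" maps to code 010 (Hate).
print(mapTcCategoriesToLGCategories("[('harassment', 0.93), ('hate', 0.41)]"))  # 010
# An empty annotation list means nothing was flagged.
print(mapTcCategoriesToLGCategories("[]"))  # None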

+ 3 - 0
src/llama_recipes/utils/dataset_utils.py

@@ -11,6 +11,7 @@ from llama_recipes.datasets import (
     get_grammar_dataset,
     get_alpaca_dataset,
     get_samsum_dataset,
+    get_llamaguard_toxicchat_dataset,
 )
 
 
@@ -54,6 +55,8 @@ DATASET_PREPROC = {
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
+    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
 }
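
With the entry in DATASET_PREPROC, the dataset resolves by name end to end. A hedged sketch, assuming get_preprocessed_dataset keeps its usual llama-recipes signature and the ToxicChat download succeeds:

from transformers import AutoTokenizer
from llama_recipes.configs.datasets import llamaguard_toxicchat_dataset
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-Guard-3-8B")
# The config's train_split/test_split fields pick the HF split to load.
train_set = get_preprocessed_dataset(tokenizer, llamaguard_toxicchat_dataset(), split="train")
print(len(train_set), list(train_set[0].keys()))  # e.g. ['input_ids', 'labels', 'attention_mask']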