# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3.1 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/lmsys/toxic-chat
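
# This script builds a Llama Guard 3 fine-tuning dataset from lmsys/toxic-chat:
# it maps ToxicChat's OpenAI moderation labels to Llama Guard 3 category codes,
# renders each example with the Llama Guard finetuning data formatter, and
# tokenizes the formatted prompts for training.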

import copy
import datasets
import itertools
from llama_recipes.inference.prompt_format_utils import LLAMA_GUARD_3_CATEGORY
import ast
import fire

def tokenize_prompt_and_labels(full_prompt, tokenizer):
    """Tokenize a formatted prompt and mirror the tokens as labels for causal LM fine-tuning."""
    prompt_tokens = tokenizer.encode(full_prompt)
    combined_tokens = {
        "input_ids": list(prompt_tokens),
        "labels": list(prompt_tokens)
    }
    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))

from llama_recipes.data.llama_guard.finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
from datasets import Dataset, DatasetInfo

def mapTcCategoriesToLGCategories(TcCategoriesString):
    """Map a ToxicChat `openai_moderation` string to the Llama Guard 3 code of its highest-scoring
    category, or return None when no categories are present."""
    TcCategories = ast.literal_eval(TcCategoriesString)
    if len(TcCategories) == 0:
        return None
    # Rank the OpenAI moderation categories by score and keep the top one.
    ranked = sorted(TcCategories, key=lambda x: x[1], reverse=True)
    primary = ranked[0][0]
    # OpenAI moderation category name -> Llama Guard 3 violated-category code.
    TcMapping = {
        "sexual": "012",
        "violence": "01",
        "sexual/minors": "04",
        "self-harm/intent": "011",
        "hate": "010",
        "harassment": "010",
        "self-harm": "011",
        "self-harm/instructions": "011",
        "violence/graphic": "01",
        "harassment/threatening": "010",
        "hate/threatening": "010"
    }
    return TcMapping[primary]
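
# Example (hypothetical input): mapTcCategoriesToLGCategories("[['harassment', 0.91], ['hate', 0.40]]")
# ranks 'harassment' highest and returns "010".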

def get_llamaguard_toxicchat_dataset(dataset_config, tokenizer, split, return_jsonl=False):
    """Load lmsys/toxic-chat (toxicchat0124) and format it as Llama Guard 3 fine-tuning data,
    or return raw JSONL-style records when return_jsonl is True."""
    dataset = datasets.load_dataset("lmsys/toxic-chat", "toxicchat0124", split=split)

    if return_jsonl:
        jsonl_data = []
        for x in dataset.to_iterable_dataset():
            jsonl_data.append({
                "prompt": x["user_input"],
                "generation": x["model_output"],
                "label": ("good" if x["toxicity"] == 0 else "bad"),
                "unsafe_content": [mapTcCategoriesToLGCategories(x["openai_moderation"])]
            })
        return jsonl_data

    guidelines = Guidelines(
        categories=LLAMA_GUARD_3_CATEGORY,
        category_code_prefix="O")

    # Setting up configs
    llama_guard_prompt_configs = LlamaGuardPromptConfigs(
        instructions_format_string="""<|begin_of_text|><|start_header_id|>user<|end_header_id|> Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{guidelines}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

{conversation}

<END CONVERSATION>

Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        should_include_category_descriptions=False,
        should_shuffle_category_codes=False
    )

    llama_guard_generation_configs = LlamaGuardGenerationConfigs(
        should_list_violated_codes=True,
        explanation_position=None
    )

    augmentation_configs = AugmentationConfigs(
        should_add_examples_with_dropped_nonviolated_prompt_categories=False,
        should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
        explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
    )

    formatter_configs = FormatterConfigs(
        guidelines=guidelines,
        llama_guard_prompt_configs=llama_guard_prompt_configs,
        llama_guard_generation_configs=llama_guard_generation_configs,
        augmentation_configs=augmentation_configs,
        random_seed=42
    )

    # Render each ToxicChat example into a formatted Llama Guard training prompt.
    dataset = dataset.map(lambda x: {"full_prompt": create_formatted_finetuning_examples(
        [TrainingExample(
            prompt=x["user_input"],
            response=None,
            violated_category_codes=[] if x["toxicity"] == 0 else [mapTcCategoriesToLGCategories(x["openai_moderation"])],
            label="safe" if x["toxicity"] == 0 else "unsafe",
            explanation="The response contains violating information."
        )],
        formatter_configs)[0]},
        remove_columns=list(dataset.features))

    # Tokenize the formatted prompts; labels mirror input_ids (see tokenize_prompt_and_labels).
    dataset = dataset.map(lambda x: tokenize_prompt_and_labels(x["full_prompt"], tokenizer), remove_columns=list(dataset.features))
    return dataset
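
# Note: get_llamaguard_toxicchat_dataset follows the (dataset_config, tokenizer, split)
# signature used by llama_recipes dataset loaders, so it is assumed it can also be
# registered as a custom dataset for the fine-tuning recipes.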

def main(return_jsonl=False):
    from transformers import AutoTokenizer

    # Local path to Llama Guard 3 (interim) HF weights; point this at your own checkpoint.
    model_id: str = "/home/ubuntu/LG3-interim-hf-weights"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if return_jsonl:
        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train", return_jsonl=True)
        print(dataset[0:50])
    else:
        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train")
        print(dataset[0])

if __name__ == '__main__':
    fire.Fire(main)
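
# Example invocation (assuming this file is saved as toxicchat_dataset.py):
#   python toxicchat_dataset.py --return_jsonl=True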