# toxicchat_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3.1 Community License Agreement.
# For dataset details visit: https://huggingface.co/datasets/lmsys/toxic-chat

import ast
import copy
import itertools

import datasets
import fire
from datasets import Dataset, DatasetInfo

from llama_recipes.data.llama_guard.finetuning_data_formatter import (
    TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs,
    LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs,
    FormatterConfigs, create_formatted_finetuning_examples,
)
from llama_recipes.inference.prompt_format_utils import LLAMA_GUARD_3_CATEGORY


def tokenize_prompt_and_labels(full_prompt, tokenizer):
    # For causal-LM fine-tuning the labels mirror the input ids; the attention
    # mask marks every token as real since no padding is applied here.
    prompt_tokens = tokenizer.encode(full_prompt)
    combined_tokens = {
        "input_ids": list(prompt_tokens),
        "labels": list(prompt_tokens)
    }
    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
def mapTcCategoriesToLGCategories(TcCategoriesString):
    # ToxicChat stores the OpenAI moderation output as a stringified list of
    # (category, score) pairs; pick the highest-scoring category and map it to
    # the corresponding Llama Guard 3 category code.
    TcCategories = ast.literal_eval(TcCategoriesString)
    if len(TcCategories) == 0:
        return None
    ranked = sorted(TcCategories, key=lambda x: x[1], reverse=True)
    primary = ranked[0][0] if len(ranked) else None
    TcMapping = {
        "sexual": "012",
        "violence": "01",
        "sexual/minors": "04",
        "self-harm/intent": "011",
        "hate": "010",
        "harassment": "010",
        "self-harm": "011",
        "self-harm/instructions": "011",
        "violence/graphic": "01",
        "harassment/threatening": "010",
        "hate/threatening": "010"
    }
    return TcMapping[primary]
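

# Illustrative only (scores invented for this sketch): for an input string like
# "[['harassment', 0.91], ['violence', 0.07]]", the highest-scoring category is
# "harassment", so the function returns "010" per the mapping above.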
def get_llamaguard_toxicchat_dataset(dataset_config, tokenizer, split, return_jsonl=False):
    dataset = datasets.load_dataset("lmsys/toxic-chat", "toxicchat0124", split=split)
    if return_jsonl:
        # Evaluation path: emit raw records rather than tokenized training examples.
        jsonl_data = []
        for x in dataset.to_iterable_dataset():
            jsonl_data.append({
                "prompt": x["user_input"],
                "generation": x["model_output"],
                "label": ("good" if x["toxicity"] == 0 else "bad"),
                "unsafe_content": [mapTcCategoriesToLGCategories(x["openai_moderation"])]
            })
        return jsonl_data
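
    # Training path: build the Llama Guard formatter configuration, render each
    # ToxicChat row into a full Llama Guard prompt, and tokenize it for fine-tuning.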
    guidelines = Guidelines(
        categories=LLAMA_GUARD_3_CATEGORY,
        category_code_prefix="O")

    # Setting up configs
    llama_guard_prompt_configs = LlamaGuardPromptConfigs(
        instructions_format_string="""<|begin_of_text|><|start_header_id|>user<|end_header_id|> Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{guidelines}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{conversation}
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        should_include_category_descriptions=False,
        should_shuffle_category_codes=False
    )
    llama_guard_generation_configs = LlamaGuardGenerationConfigs(
        should_list_violated_codes=True,
        explanation_position=None
    )

    augmentation_configs = AugmentationConfigs(
        should_add_examples_with_dropped_nonviolated_prompt_categories=False,
        should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
        explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
    )

    formatter_configs = FormatterConfigs(
        guidelines=guidelines,
        llama_guard_prompt_configs=llama_guard_prompt_configs,
        llama_guard_generation_configs=llama_guard_generation_configs,
        augmentation_configs=augmentation_configs,
        random_seed=42
    )
    # remove_columns drops the raw ToxicChat fields so that only the formatted
    # prompt (and, after the second map, the tokenized fields) remain.
    dataset = dataset.map(
        lambda x: {
            "full_prompt": create_formatted_finetuning_examples(
                [TrainingExample(
                    prompt=x["user_input"],
                    response=None,
                    violated_category_codes=[] if x["toxicity"] == 0 else [mapTcCategoriesToLGCategories(x["openai_moderation"])],
                    label="safe" if x["toxicity"] == 0 else "unsafe",
                    explanation="The response contains violating information."
                )],
                formatter_configs)[0]
        },
        remove_columns=list(dataset.features))
    dataset = dataset.map(
        lambda x: tokenize_prompt_and_labels(x["full_prompt"], tokenizer),
        remove_columns=list(dataset.features))
    return dataset
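

# Each row in the returned training dataset is a dict of equal-length token lists,
# e.g. {"input_ids": [...], "labels": [...], "attention_mask": [1, 1, ...]}, which
# is the per-example format the llama-recipes fine-tuning code consumes.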
def main(return_jsonl=False):
    from transformers import AutoTokenizer

    # Local path to the Llama Guard 3 checkpoint; point this at your own weights.
    model_id: str = "/home/ubuntu/LG3-interim-hf-weights"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if return_jsonl:
        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train", return_jsonl=True)
        print(dataset[0:50])
    else:
        dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train")
        print(dataset[0])
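

# Example invocations via python-fire (the checkpoint path above must exist locally):
#   python toxicchat_dataset.py                      # print the first tokenized example
#   python toxicchat_dataset.py --return_jsonl=True  # print the first 50 raw records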
if __name__ == '__main__':
    fire.Fire(main)