finetuning_data_formatter.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama Guard License Agreement.

import copy
import random
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Literal, Optional, Sequence


@dataclass
class Category:
    name: str
    description: str


@dataclass
class Guidelines:
    categories: Sequence[Category]
    category_code_prefix: str = "O"


class ExplanationPosition(Enum):
    BEFORE_DECISION = 0
    AFTER_DECISION = 1


@dataclass
class LlamaGuardPromptConfigs:
    instructions_format_string: str
    should_include_category_descriptions: bool
    should_shuffle_category_codes: bool = True


@dataclass
class LlamaGuardGenerationConfigs:
    should_list_violated_codes: bool
    explanation_position: Optional[ExplanationPosition]


@dataclass
class AugmentationConfigs:
    should_add_examples_with_dropped_nonviolated_prompt_categories: bool = True
    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories: bool = (
        False
    )
    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories: Optional[
        str
    ] = None


@dataclass
class FormatterConfigs:
    guidelines: Guidelines
    llama_guard_prompt_configs: LlamaGuardPromptConfigs
    llama_guard_generation_configs: LlamaGuardGenerationConfigs
    augmentation_configs: AugmentationConfigs
    # Allows subsequent reruns to reuse a stable seed for reproducibility
    random_seed: int = 42


@dataclass
class TrainingExample:
    prompt: str
    response: str
    violated_category_codes: List[str]
    label: Literal["safe", "unsafe"]
    explanation: Optional[str] = None

def create_formatted_finetuning_examples(
    training_examples: Sequence[TrainingExample],
    formatter_configs: FormatterConfigs,
) -> List[str]:
    """
    This formatter takes consumer-provided training examples and converts them to
    the right format for finetuning llama-guard.

    There are various configuration options available.

    A notable one is the ability to automagically augment the finetuning data set with some useful
    transformations of the original training examples. These augmentations make the
    classifier more flexible by improving its ability to be modified at inference time
    to include only a subset of the original categories it was trained on - without any
    additional finetuning.

    Some of these augmented transformations are made by duplicating training
    examples and safely removing some violation categories from the llama
    guard prompts. Because of this, in some of this file you will see
    references to "original" category indices/codes and rewritten ones. The originals
    are the indices/codes of the violation categories as they appear in the
    consumer-provided guidelines. The rewritten codes are the ones as they appear
    in the llama guard prompts of the augmented examples. We occasionally need to
    convert between the two.
    """
    _verify_formatter_configs(formatter_configs)

    random.seed(formatter_configs.random_seed)

    indices_of_all_categories = range(len(formatter_configs.guidelines.categories))

    to_return = []

    for training_example in training_examples:
        to_return.append(
            _create_formatted_finetuning_example(
                training_example,
                formatter_configs,
                category_indices_to_include_in_llama_guard_prompt=list(
                    indices_of_all_categories
                ),
            )
        )

        _maybe_add_data_augmentations_for_example(
            training_example, to_return, indices_of_all_categories, formatter_configs
        )

    return to_return

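# Each returned string is the rendered llama guard prompt, a single space, and the
# target generation. As an illustration (not verbatim output), an "unsafe" example
# with should_list_violated_codes=True and explanation_position=AFTER_DECISION
# yields a generation of the form:
#
#   unsafe
#   O1,O3
#   Explanation: <the example's explanation>
#
# The exact codes depend on which categories were included in the prompt and on
# whether category shuffling is enabled.
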
def _verify_formatter_configs(
    formatter_configs: FormatterConfigs,
) -> None:
    if (
        formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
        and formatter_configs.llama_guard_generation_configs.explanation_position
        is not None
        and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
        is None
    ):
        raise ValueError(
            """The configuration setup requires you to specify
explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories.
This is an explanation that we use for dynamically-created safe augmentation examples.
Consider something like 'This interaction is safe because any riskiness it contains
is related to violation categories that we're explicitly not trying to detect here.'"""
        )

def _create_formatted_finetuning_example(
    training_example: TrainingExample,
    formatter_configs: FormatterConfigs,
    category_indices_to_include_in_llama_guard_prompt: List[int],
) -> str:
    if formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes:
        random.shuffle(category_indices_to_include_in_llama_guard_prompt)
    else:
        category_indices_to_include_in_llama_guard_prompt = sorted(
            category_indices_to_include_in_llama_guard_prompt
        )

    llama_guard_prompt = _create_llama_guard_prompt(
        training_example,
        category_indices_to_include_in_llama_guard_prompt,
        formatter_configs,
    )

    llama_guard_generation = _create_llama_guard_generation(
        training_example,
        category_indices_to_include_in_llama_guard_prompt,
        formatter_configs,
    )

    return f"{llama_guard_prompt} {llama_guard_generation}"

def _create_llama_guard_prompt(
    training_example: TrainingExample,
    category_indices_to_include: List[int],
    formatter_configs: FormatterConfigs,
) -> str:
    full_guidelines_text = ""

    for (
        rewritten_category_index_for_current_prompt,
        original_category_index,
    ) in enumerate(category_indices_to_include):
        category = formatter_configs.guidelines.categories[original_category_index]

        newline_for_every_category_after_first = (
            "\n" if rewritten_category_index_for_current_prompt > 0 else ""
        )

        # Indices start at 0, but categories start at 1, so we add 1
        full_guidelines_text += f"{newline_for_every_category_after_first}{formatter_configs.guidelines.category_code_prefix}{rewritten_category_index_for_current_prompt + 1}: {category.name}. "

        if (
            formatter_configs.llama_guard_prompt_configs.should_include_category_descriptions
        ):
            full_guidelines_text += f"\n{category.description}"

    conversation = {"human": training_example.prompt}

    if not _is_a_prompt_only_example(training_example):
        conversation["chatbot"] = training_example.response

    return formatter_configs.llama_guard_prompt_configs.instructions_format_string.format_map(
        {
            "guidelines": full_guidelines_text,
            "conversation": _serialize_conversation(conversation),
        }
    )

def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
    return training_example.response == "N/A"


def _serialize_conversation(conversation: Dict[str, str]) -> str:
    conversation_as_list = []

    for speaker, message in conversation.items():
        conversation_as_list.append(f"{speaker}: {message}")

    return "\n\n".join(conversation_as_list)

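# For a prompt+response example, _serialize_conversation produces (illustrative):
#
#   human: <prompt text>
#
#   chatbot: <response text>
#
# Prompt-only examples (response == "N/A") serialize only the "human" turn.
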
def _create_llama_guard_generation(
    training_example: TrainingExample,
    category_indices_included_in_llama_guard_prompt: List[int],
    formatter_configs: FormatterConfigs,
) -> str:
    to_return = training_example.label

    if (
        training_example.label == "unsafe"
        and formatter_configs.llama_guard_generation_configs.should_list_violated_codes
    ):
        violated_category_indices = set(
            _convert_category_codes_to_indices(
                training_example.violated_category_codes,
                formatter_configs,
            )
        )

        map_of_original_category_indices_to_rewritten_category_codes = (
            _get_map_of_original_category_indices_to_rewritten_category_codes(
                formatter_configs, category_indices_included_in_llama_guard_prompt
            )
        )

        rewritten_violated_category_codes = sorted(
            [
                map_of_original_category_indices_to_rewritten_category_codes[
                    violated_index
                ]
                for violated_index in violated_category_indices
            ]
        )

        to_return += "\n"
        to_return += ",".join(rewritten_violated_category_codes)

    explanation_position = (
        formatter_configs.llama_guard_generation_configs.explanation_position
    )

    if explanation_position == ExplanationPosition.BEFORE_DECISION:
        to_return = f"Explanation: {training_example.explanation}\n{to_return}"
    elif explanation_position == ExplanationPosition.AFTER_DECISION:
        to_return = f"{to_return}\nExplanation: {training_example.explanation}"

    return to_return

def _get_map_of_original_category_indices_to_rewritten_category_codes(
    formatter_configs: FormatterConfigs,
    category_indices_included_in_llama_guard_prompt: List[int],
) -> Dict[int, str]:
    to_return = {}

    for rewritten_category_index, original_category_index in enumerate(
        category_indices_included_in_llama_guard_prompt
    ):
        to_return[
            original_category_index
        ] = formatter_configs.guidelines.category_code_prefix + str(
            rewritten_category_index + 1
        )

    return to_return

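# Illustrative mapping: if the (possibly shuffled) prompt includes original category
# indices [2, 0], the map built above is {2: "O1", 0: "O2"} - the first category
# listed in this particular prompt is relabeled O1, the second O2, and so on.
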
def _maybe_add_data_augmentations_for_example(
    training_example: TrainingExample,
    formatted_examples_being_built: List[str],
    indices_of_all_categories: range,
    formatter_configs: FormatterConfigs,
) -> None:
    violated_category_indices = _convert_category_codes_to_indices(
        training_example.violated_category_codes,
        formatter_configs,
    )

    nonviolated_category_indices = list(
        set(indices_of_all_categories) - set(violated_category_indices)
    )

    _maybe_add_example_with_dropped_nonviolated_prompt_categories(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        nonviolated_category_indices,
        formatter_configs,
    )

    _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        violated_category_indices,
        nonviolated_category_indices,
        formatter_configs,
    )

def _convert_category_codes_to_indices(
    codes: List[str], formatter_configs: FormatterConfigs
) -> List[int]:
    # Category codes start at 1, but indices start at 0, so we subtract 1
    return [
        int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
        for code in codes
    ]

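# Illustrative conversion with the default category_code_prefix "O":
# _convert_category_codes_to_indices(["O1", "O3"], formatter_configs) -> [0, 2]
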
def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
    training_example: TrainingExample,
    formatted_examples_being_built: List[str],
    indices_of_all_categories: range,
    nonviolated_category_indices: List[int],
    formatter_configs: FormatterConfigs,
) -> None:
    """
    If a prompt+response pair does not violate certain categories, we can augment
    the data by duplicating the training example but removing some of the non-violated
    categories from the llama guard prompt. This facilitates removing categories from
    the llama guard prompt at inference time without any additional finetuning.
    """
    if (
        not formatter_configs.augmentation_configs.should_add_examples_with_dropped_nonviolated_prompt_categories
    ):
        return

    number_of_categories_to_drop = random.randint(0, len(nonviolated_category_indices))

    # Never drop every category: the llama guard prompt must always list at least one.
    if number_of_categories_to_drop == len(indices_of_all_categories):
        number_of_categories_to_drop -= 1

    dropped_category_indices = random.sample(
        nonviolated_category_indices, number_of_categories_to_drop
    )

    retained_category_indices = list(
        set(indices_of_all_categories) - (set(dropped_category_indices))
    )

    formatted_examples_being_built.append(
        _create_formatted_finetuning_example(
            training_example,
            formatter_configs,
            category_indices_to_include_in_llama_guard_prompt=retained_category_indices,
        )
    )

def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
    training_example: TrainingExample,
    formatted_examples_being_built: List[str],
    indices_of_all_categories: range,
    violated_category_indices: List[int],
    nonviolated_category_indices: List[int],
    formatter_configs: FormatterConfigs,
) -> None:
    """
    Same as in _maybe_add_example_with_dropped_nonviolated_prompt_categories but we
    also drop all of the violated categories from the llama guard prompt.
    """
    if (
        training_example.label == "safe"
        or not formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
    ):
        return

    random_nonviolated_category_indices_to_drop = random.sample(
        nonviolated_category_indices,
        random.randint(0, len(nonviolated_category_indices) - 1),
    )

    set_of_retained_category_indices = (
        set(indices_of_all_categories)
        - set(violated_category_indices)
        - set(random_nonviolated_category_indices_to_drop)
    )

    # Since every violated category is dropped from the prompt, the duplicated
    # example is relabeled as safe and given the configured augmentation explanation.
    training_example_copy = copy.deepcopy(training_example)
    training_example_copy.label = "safe"
    training_example_copy.violated_category_codes = []
    training_example_copy.explanation = (
        formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
    )

    formatted_examples_being_built.append(
        _create_formatted_finetuning_example(
            training_example_copy,
            formatter_configs,
            category_indices_to_include_in_llama_guard_prompt=list(
                set_of_retained_category_indices
            ),
        )
    )

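# Minimal usage sketch. The instructions_format_string below is a simplified
# stand-in that only supplies the {guidelines} and {conversation} placeholders
# format_map expects; it is NOT the official Llama Guard prompt template, and the
# categories and examples are illustrative values only.
if __name__ == "__main__":
    guidelines = Guidelines(
        categories=[
            Category(name="Violence.", description="Content that promotes violence."),
            Category(name="Hate.", description="Content that demeans people."),
        ],
        category_code_prefix="O",
    )

    formatter_configs = FormatterConfigs(
        guidelines=guidelines,
        llama_guard_prompt_configs=LlamaGuardPromptConfigs(
            instructions_format_string=(
                "[INST] Classify the conversation below against these categories:\n"
                "{guidelines}\n\nConversation:\n{conversation} [/INST]"
            ),
            should_include_category_descriptions=True,
            should_shuffle_category_codes=False,
        ),
        llama_guard_generation_configs=LlamaGuardGenerationConfigs(
            should_list_violated_codes=True,
            explanation_position=ExplanationPosition.AFTER_DECISION,
        ),
        augmentation_configs=AugmentationConfigs(
            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
        ),
        random_seed=42,
    )

    examples = [
        TrainingExample(
            prompt="How do I pick a lock?",
            response="N/A",
            violated_category_codes=[],
            label="safe",
            explanation="The request is not covered by the listed categories.",
        ),
    ]

    # Prints the original formatted example plus any augmented duplicates.
    for formatted in create_formatted_finetuning_examples(examples, formatter_configs):
        print(formatted)
        print("---")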