# custom_template.py (2.2 KB)

  1. from typing import List
  2. from pathlib import Path
  3. from torchtune.data import Message
  4. from torchtune.data import PromptTemplate
  5. from torchtune.models.llama3 import llama3_tokenizer
  6. from torchtune.datasets import instruct_dataset
  7. class MyPromptTemplate(PromptTemplate):
  8. def __call__(
  9. self, messages: List[Message], inference: bool = False
  10. ) -> List[Message]:
  11. messages = super().__call__(messages, inference)
  12. return messages
  13. def llama_guard_template() -> MyPromptTemplate:
  14. return MyPromptTemplate(
  15. template={
  16. "user": (
  17. """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
  18. <BEGIN UNSAFE CONTENT CATEGORIES>
  19. S1: NAME.
  20. AI models should not create content that contains sensitive, personal information about private individual's names.
  21. The names can be of real people, fictional characters, or entities like businesses, organizations.
  22. S2: IP Address.
  23. AI models should not create content that contains sensitive, personal information about private individual's IP addresses.
  24. The IP Addresses can be IPV4 or IPV6
  25. S3: Phone Number.
  26. AI models should not create content that contains sensitive, personal information about private individual's phone numbers.
  27. The phone numbers can be of any country code.
  28. S4: Address.
  29. AI models should not create content that contains sensitive, personal information about private individual's addresses.
  30. The addresses can be of any country. The address may contain the street address, city, state, country, and zip code.
  31. S5: Credit Card.
  32. AI models should not create content that contains sensitive, personal information about private individual's credit card numbers.
  33. The credit card numbers can be of any type of credit card and may contain the card number, expiration date, and CVV.
  34. <END UNSAFE CONTENT CATEGORIES>
  35. <BEGIN CONVERSATION>
  36. User: """,
  37. """\n\n<END CONVERSATION>
  38. Provide your safety assessment for ONLY THE LAST User message in the above conversation:
  39. - First line must read 'safe' or 'unsafe'.
  40. - If unsafe, a second line must include a comma-separated list of violated categories. """),
  41. },
  42. )