custom_template.py

from pathlib import Path
from typing import List

from torchtune.data import Message, PromptTemplate
from torchtune.datasets import instruct_dataset
from torchtune.models.llama3 import llama3_tokenizer


class MyPromptTemplate(PromptTemplate):
    """Custom template stub: delegates to PromptTemplate, which prepends and
    appends the configured tags to each message of the matching role.
    Override __call__ to customize that behavior."""

    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        messages = super().__call__(messages, inference)
        return messages


def llama_guard_template() -> MyPromptTemplate:
    # Wraps every user message in the Llama Guard safety-classification prompt:
    # the first string in the tuple is prepended and the second appended to the
    # message content.
    return MyPromptTemplate(
        template={
            "user": (
                """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User: """,
                """\n\n<END CONVERSATION>
Provide your safety assessment for ONLY THE LAST User message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. """,
            ),
        },
    )
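

# A minimal usage sketch (not part of the listing above): it assumes this file
# is importable as ``custom_template`` and that a Llama 3 tokenizer checkpoint
# and a JSON instruct dataset exist at the placeholder paths below.
if __name__ == "__main__":
    tokenizer = llama3_tokenizer(
        # Placeholder path to a downloaded tokenizer.model file.
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        # Dotpath string resolved by torchtune to the builder defined above.
        prompt_template="custom_template.llama_guard_template",
    )
    ds = instruct_dataset(
        tokenizer=tokenizer,
        source="json",
        data_files="my_safety_data.json",  # placeholder dataset file
        split="train",
    )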