# 1B-debating-script.py
# Command-line script: two instruct-tuned LLaMA models debate a topic turn by turn.
  1. import argparse
  2. import torch
  3. from accelerate import Accelerator
  4. from transformers import AutoModelForCausalLM, AutoTokenizer
  5. accelerator = Accelerator()
  6. device = accelerator.device
  7. # Constants
  8. DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
  9. def load_model_and_tokenizer(model_name: str):
  10. """
  11. Load the model and tokenizer for LLaMA-1B.
  12. """
  13. model = AutoModelForCausalLM.from_pretrained(
  14. model_name,
  15. torch_dtype=torch.bfloat16,
  16. use_safetensors=True,
  17. device_map=device,
  18. )
  19. tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
  20. model, tokenizer = accelerator.prepare(model, tokenizer)
  21. return model, tokenizer
  22. def generate_response(model, tokenizer, conversation, temperature: float, top_p: float):
  23. """
  24. Generate a response based on the conversation history.
  25. """
  26. prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
  27. inputs = tokenizer(prompt, return_tensors="pt").to(device)
  28. output = model.generate(
  29. **inputs, temperature=temperature, top_p=top_p, max_new_tokens=256
  30. )
  31. return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :].strip()
  32. def debate(
  33. model1,
  34. model2,
  35. tokenizer,
  36. system_prompt1,
  37. system_prompt2,
  38. initial_topic,
  39. n_turns,
  40. temperature,
  41. top_p,
  42. ):
  43. """
  44. Conduct a debate between two models.
  45. """
  46. conversation1 = [
  47. {"role": "system", "content": system_prompt1},
  48. {"role": "user", "content": f"Let's debate about: {initial_topic}"},
  49. ]
  50. conversation2 = [
  51. {"role": "system", "content": system_prompt2},
  52. {"role": "user", "content": f"Let's debate about: {initial_topic}"},
  53. ]
  54. for i in range(n_turns):
  55. print(f"\nTurn {i+1}:")
  56. # Model 1's turn
  57. response1 = generate_response(
  58. model1, tokenizer, conversation1, temperature, top_p
  59. )
  60. print(f"Model 1: {response1}")
  61. conversation1.append({"role": "assistant", "content": response1})
  62. conversation2.append({"role": "user", "content": response1})
  63. # Model 2's turn
  64. response2 = generate_response(
  65. model2, tokenizer, conversation2, temperature, top_p
  66. )
  67. print(f"Model 2: {response2}")
  68. conversation2.append({"role": "assistant", "content": response2})
  69. conversation1.append({"role": "user", "content": response2})
  70. def main(
  71. system_prompt1: str,
  72. system_prompt2: str,
  73. initial_topic: str,
  74. n_turns: int,
  75. temperature: float,
  76. top_p: float,
  77. model_name: str,
  78. ):
  79. """
  80. Set up and run the debate.
  81. """
  82. model1, tokenizer = load_model_and_tokenizer(model_name)
  83. model2, _ = load_model_and_tokenizer(model_name) # We can reuse the tokenizer
  84. debate(
  85. model1,
  86. model2,
  87. tokenizer,
  88. system_prompt1,
  89. system_prompt2,
  90. initial_topic,
  91. n_turns,
  92. temperature,
  93. top_p,
  94. )
  95. if __name__ == "__main__":
  96. parser = argparse.ArgumentParser(
  97. description="Conduct a debate between two LLaMA-1B models."
  98. )
  99. parser.add_argument(
  100. "--system_prompt1",
  101. type=str,
  102. default="You are a passionate advocate for technology and innovation.",
  103. help="System prompt for the first model (default: 'You are a passionate advocate for technology and innovation.')",
  104. )
  105. parser.add_argument(
  106. "--system_prompt2",
  107. type=str,
  108. default="You are a cautious critic of rapid technological change.",
  109. help="System prompt for the second model (default: 'You are a cautious critic of rapid technological change.')",
  110. )
  111. parser.add_argument(
  112. "--initial_topic", type=str, required=True, help="Initial topic for the debate"
  113. )
  114. parser.add_argument(
  115. "--n_turns",
  116. type=int,
  117. default=5,
  118. help="Number of turns in the debate (default: 5)",
  119. )
  120. parser.add_argument(
  121. "--temperature",
  122. type=float,
  123. default=0.7,
  124. help="Temperature for generation (default: 0.7)",
  125. )
  126. parser.add_argument(
  127. "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
  128. )
  129. parser.add_argument(
  130. "--model_name",
  131. type=str,
  132. default=DEFAULT_MODEL,
  133. help=f"Model name (default: '{DEFAULT_MODEL}')",
  134. )
  135. args = parser.parse_args()
  136. main(
  137. args.system_prompt1,
  138. args.system_prompt2,
  139. args.initial_topic,
  140. args.n_turns,
  141. args.temperature,
  142. args.top_p,
  143. args.model_name,
  144. )