gradio-app.py

import gradio as gr
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
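# Note: the meta-llama checkpoints are gated on the Hugging Face Hub, so you
# may need to accept the license and authenticate (e.g. `huggingface-cli login`)
# before the weights will download.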


def load_model_and_tokenizer(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
    )
    # Tokenizers are plain Python objects: no safetensors flag applies to them,
    # and only the model needs to go through `accelerator.prepare`.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = accelerator.prepare(model)
    return model, tokenizer


def generate_response(model, tokenizer, conversation, temperature: float, top_p: float):
    # `add_generation_prompt=True` cues the model to answer as the assistant;
    # `do_sample=True` is required for temperature/top_p to have any effect.
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(
        **inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=256
    )
    # Decode only the newly generated tokens; slicing the decoded string by
    # `len(prompt)` is unreliable once special tokens are stripped.
    new_tokens = output[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def debate(
    model1,
    model2,
    tokenizer,
    system_prompt1,
    system_prompt2,
    initial_topic,
    n_turns,
    temperature,
    top_p,
):
    # Each model keeps its own view of the conversation: its own replies are
    # "assistant" turns, while the opponent's replies arrive as "user" turns.
    conversation1 = [
        {"role": "system", "content": system_prompt1},
        {"role": "user", "content": f"Let's debate about: {initial_topic}"},
    ]
    conversation2 = [
        {"role": "system", "content": system_prompt2},
        {"role": "user", "content": f"Let's debate about: {initial_topic}"},
    ]
    debate_history = []
    for _ in range(int(n_turns)):  # cast: Gradio sliders may deliver floats
        # Model 1's turn
        response1 = generate_response(
            model1, tokenizer, conversation1, temperature, top_p
        )
        debate_history.append(f"Model 1: {response1}")
        conversation1.append({"role": "assistant", "content": response1})
        conversation2.append({"role": "user", "content": response1})
        yield "\n".join(debate_history)
        # Model 2's turn
        response2 = generate_response(
            model2, tokenizer, conversation2, temperature, top_p
        )
        debate_history.append(f"Model 2: {response2}")
        conversation2.append({"role": "assistant", "content": response2})
        conversation1.append({"role": "user", "content": response2})
        yield "\n".join(debate_history)
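

# `debate` is a generator, and Gradio streams output from generator functions,
# so the Debate textbox updates after every turn instead of only at the end.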
def create_gradio_interface():
    model1, tokenizer = load_model_and_tokenizer(DEFAULT_MODEL)
    model2, _ = load_model_and_tokenizer(DEFAULT_MODEL)  # We can reuse the tokenizer
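    # Two copies of the same checkpoint play the two debaters. Since `generate`
    # keeps no state between calls, a single shared instance would also work;
    # two copies just make the two-agent setup explicit, at roughly double the
    # GPU memory (~2.5 GB per copy at bfloat16 for a 1B-parameter model).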

    def gradio_debate(
        system_prompt1, system_prompt2, initial_topic, n_turns, temperature, top_p
    ):
        debate_generator = debate(
            model1,
            model2,
            tokenizer,
            system_prompt1,
            system_prompt2,
            initial_topic,
            n_turns,
            temperature,
            top_p,
        )
        debate_text = ""
        for turn in debate_generator:
            debate_text = turn
            yield debate_text

    iface = gr.Interface(
        fn=gradio_debate,
        inputs=[
            gr.Textbox(
                label="System Prompt 1",
                value="You are a passionate advocate for technology and innovation.",
            ),
            gr.Textbox(
                label="System Prompt 2",
                value="You are a cautious critic of rapid technological change.",
            ),
            gr.Textbox(
                label="Initial Topic",
                value="The impact of artificial intelligence on society",
            ),
            gr.Slider(minimum=1, maximum=10, step=1, label="Number of Turns", value=5),
            gr.Slider(
                minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7
            ),
            gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Top P", value=0.9),
        ],
        outputs=gr.Textbox(label="Debate", lines=20),
        title="LLaMA 1B Model Debate",
        description="Watch two LLaMA 1B models debate on a topic of your choice!",
        live=False,  # Changed to False to prevent auto-updates
    )
    return iface


if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch(share=True)
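
# To try this locally (a sketch of the setup, not pinned versions):
#   pip install gradio torch transformers accelerate
#   python gradio-app.py
# `share=True` additionally requests a temporary public Gradio link on top of
# the local server.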