Gradio-App

Sanyam Bhutani committed 7 months ago
parent commit c9fffe8b0b
1 changed file with 131 additions and 0 deletions

+ 131 - 0
recipes/quickstart/NotebookLlama/gradio-app.py

@@ -0,0 +1,131 @@
+import gradio as gr
+import torch
+from accelerate import Accelerator
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
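+# Accelerator resolves the best available device (CUDA GPU, Apple MPS, or CPU)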
+accelerator = Accelerator()
+device = accelerator.device
+
+# Constants
+DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
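+# any instruct-tuned checkpoint with a chat template should work as a drop-in;
+# the meta-llama repos are gated, so `huggingface-cli login` may be required first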
+
+
+def load_model_and_tokenizer(model_name: str):
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
+        device_map=device,
+    )
+    # use_safetensors is a model kwarg, not a tokenizer one, so it is dropped here
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # accelerator.prepare targets models, optimizers, and dataloaders; a tokenizer
+    # would pass through untouched, so only the model is prepared here
+    model = accelerator.prepare(model)
+    return model, tokenizer
+
+
+def generate_response(model, tokenizer, conversation, temperature: float, top_p: float):
+    # add_generation_prompt appends the assistant header so the model answers in turn
+    prompt = tokenizer.apply_chat_template(
+        conversation, tokenize=False, add_generation_prompt=True
+    )
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    output = model.generate(
+        **inputs, do_sample=True,  # sampling must be enabled for temperature/top_p
+        temperature=temperature, top_p=top_p, max_new_tokens=256,
+    )
+    # decode only the newly generated tokens; slicing the decoded string by
+    # len(prompt) misaligns because skip_special_tokens strips the template tokens
+    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
+    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+
+def debate(
+    model1,
+    model2,
+    tokenizer,
+    system_prompt1,
+    system_prompt2,
+    initial_topic,
+    n_turns,
+    temperature,
+    top_p,
+):
+    conversation1 = [
+        {"role": "system", "content": system_prompt1},
+        {"role": "user", "content": f"Let's debate about: {initial_topic}"},
+    ]
+    conversation2 = [
+        {"role": "system", "content": system_prompt2},
+        {"role": "user", "content": f"Let's debate about: {initial_topic}"},
+    ]
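+
+    # each model keeps its own transcript: its own replies are "assistant" turns,
+    # while the opponent's replies arrive as "user" turns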
+
+    debate_history = []
+
+    # yield after each model's turn so the transcript streams into the UI
+    for _ in range(int(n_turns)):  # Gradio sliders can deliver floats, so coerce
+        # Model 1's turn
+        response1 = generate_response(
+            model1, tokenizer, conversation1, temperature, top_p
+        )
+        debate_history.append(f"Model 1: {response1}")
+        conversation1.append({"role": "assistant", "content": response1})
+        conversation2.append({"role": "user", "content": response1})
+        yield "\n".join(debate_history)
+
+        # Model 2's turn
+        response2 = generate_response(
+            model2, tokenizer, conversation2, temperature, top_p
+        )
+        debate_history.append(f"Model 2: {response2}")
+        conversation2.append({"role": "assistant", "content": response2})
+        conversation1.append({"role": "user", "content": response2})
+        yield "\n".join(debate_history)
+
+
+def create_gradio_interface():
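+    # both debaters load the same checkpoint; generation never mutates model
+    # state, so model2 could simply alias model1 to halve memory use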
+    model1, tokenizer = load_model_and_tokenizer(DEFAULT_MODEL)
+    model2, _ = load_model_and_tokenizer(DEFAULT_MODEL)  # We can reuse the tokenizer
+
+    def gradio_debate(
+        system_prompt1, system_prompt2, initial_topic, n_turns, temperature, top_p
+    ):
+        debate_generator = debate(
+            model1,
+            model2,
+            tokenizer,
+            system_prompt1,
+            system_prompt2,
+            initial_topic,
+            n_turns,
+            temperature,
+            top_p,
+        )
+        # stream each completed turn straight through to the output textbox
+        yield from debate_generator
+
+    iface = gr.Interface(
+        fn=gradio_debate,
+        inputs=[
+            gr.Textbox(
+                label="System Prompt 1",
+                value="You are a passionate advocate for technology and innovation.",
+            ),
+            gr.Textbox(
+                label="System Prompt 2",
+                value="You are a cautious critic of rapid technological change.",
+            ),
+            gr.Textbox(
+                label="Initial Topic",
+                value="The impact of artificial intelligence on society",
+            ),
+            gr.Slider(minimum=1, maximum=10, step=1, label="Number of Turns", value=5),
+            gr.Slider(
+                minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7
+            ),
+            gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Top P", value=0.9),
+        ],
+        outputs=gr.Textbox(label="Debate", lines=20),
+        title="Llama 3.2 1B Model Debate",
+        description="Watch two Llama 3.2 1B models debate a topic of your choice!",
+        live=False,  # run only on submit instead of on every input change
+    )
+    return iface
+
+
+if __name__ == "__main__":
+    iface = create_gradio_interface()
+    iface.launch(share=True)
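
To try the demo locally (a sketch, assuming gradio, torch, transformers, and accelerate are installed and the gated Llama checkpoint is accessible):

    python recipes/quickstart/NotebookLlama/gradio-app.py

With share=True, Gradio prints a local URL plus a temporary public link.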