Changes the UI from a textbox to a chatbox, with max_tokens, top_k, temperature, and top_p sliders.

himanshushukla12 6 months ago
parent
commit
c0405b6efc

+ 122 - 127
recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py

@@ -1,142 +1,137 @@
-import os
-from PIL import Image as PIL_Image
+import gradio as gr
 import torch
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
+import os
+from PIL import Image
 from accelerate import Accelerator
-import gradio as gr
-import gc  # Import garbage collector
-import argparse  # Import argparse for command-line arguments
+from transformers import MllamaForConditionalGeneration, AutoProcessor
+accelerator = Accelerator()
+device = accelerator.device
+# Set memory management for PyTorch
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # or adjust size as needed
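+# max_split_size_mb caps how large a cached block the allocator will split, which reduces fragmentation-related OOMs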
+
+
+model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+model = MllamaForConditionalGeneration.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+)
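+# Note: in bfloat16 the 11B vision model needs roughly 21 GB of GPU memory for the weights alone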
+processor = AutoProcessor.from_pretrained(model_id)
 
 
+# Visual theme
+visual_theme = gr.themes.Default()  # Default, Soft or Monochrome
+
 # Constants
-DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-class LlamaInference:
-    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
-        """
-        Initialize the inference class. Load model and processor.
-        """
-        self.accelerator = Accelerator()
-        self.device = self.accelerator.device
-        
-        self.model_name = model_name
-        self.hf_token = hf_token
-        
-        if self.hf_token is None:
-            raise ValueError("Error: Hugging Face token not provided.")
+# Function to process the image and generate a description
+def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+    # Resize the image to the maximum supported input size
+    image = image.resize(MAX_IMAGE_SIZE)
 
-        # Load model and processor
-        self.model, self.processor = self.load_model_and_processor()
+    # Initialize cleaned_output variable
+    cleaned_output = ""
 
-    def load_model_and_processor(self):
-        """
-        Load the model and processor based on the model name.
-        """
-        model = MllamaForConditionalGeneration.from_pretrained(self.model_name, 
-                                                               torch_dtype=torch.bfloat16, 
-                                                               use_safetensors=True, 
-                                                               device_map=self.device, 
-                                                               token=self.hf_token)
-        processor = MllamaProcessor.from_pretrained(self.model_name, 
-                                                    token=self.hf_token, 
-                                                    use_safetensors=True)
-
-        # Prepare model and processor with accelerator
-        model, processor = self.accelerator.prepare(model, processor)
-        return model, processor
-
-    def process_image(self, image) -> PIL_Image.Image:
-        """
-        Open and convert an uploaded image to RGB format.
-        """
-        return image.convert("RGB")
 
-    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
-        """
-        Generate text from an image using the model and processor.
-        """
-        conversation = [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
-        ]
-        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)
-
-        # Perform inference without computing gradients to save memory
-        with torch.no_grad():
-            output = self.model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-        
-        return self.processor.decode(output[0])[len(prompt):]
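+    # Llama 3.2 Vision prompt format: the <|image|> token marks where the image is injected, followed by the text prompt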
+    prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
+    # Preprocess the image and prompt
+    inputs = processor(image, prompt, return_tensors="pt").to(device)
 
-    def cleanup(self):
-        """
-        Clean up instance variables to release memory.
-        """
-        # Move model and processor to CPU before deleting to free up GPU memory
-        self.model.to('cpu')
-        del self.model
-        del self.processor
-        torch.cuda.empty_cache()  # Release GPU memory
-        gc.collect()  # Force garbage collection
-
-        # Clear other instance variables
-        del self.accelerator
-        del self.device
-        del self.hf_token
-
-        print("Cleanup complete. Instance variables deleted and memory cleared.")
-
-
-def inference(image, prompt_text, temperature, top_p, hf_token):
-    """
-    Main inference function to handle Gradio inputs and manage memory cleanup.
-    """
-    # Initialize the inference instance (this loads the model)
-    llama_inference = LlamaInference(hf_token=hf_token)
-
-    try:
-        # Process the image and generate text
-        processed_image = llama_inference.process_image(image)
-        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
-    finally:
-        # Perform memory cleanup
-        llama_inference.cleanup()
-
-    return result
-
-
-# Gradio UI
-def create_gradio_interface(hf_token):
-    """
-    Create the Gradio interface for image-to-text generation.
-    """
-    # Define the input components
-    image_input = gr.Image(type="pil", label="Upload Image")
-    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
-    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
-    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")
-
-    # Define the output component
-    output_text = gr.Textbox(label="Generated Text")
-
-    # Create the interface
-    interface = gr.Interface(
-        fn=lambda image, prompt_text, temperature, top_p: inference(image, prompt_text, temperature, top_p, hf_token),
-        inputs=[image_input, prompt_input, temperature_input, top_p_input],
-        outputs=output_text,
-        title="LLama-3.2 Vision-Instruct",
-        description="Generate descriptive text from an image using the LLama model.",
-        theme="default",
+    # Generate output with model
+    output = model.generate(
+        **inputs,
+        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
+        do_sample=True,  # enable sampling so the temperature/top_k/top_p sliders take effect
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p
     )
+
+    # Decode the raw output
+    raw_output = processor.decode(output[0])
+    
+    # Clean up the output to remove system tokens
+    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
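+    # processor.decode(output[0], skip_special_tokens=True) would also drop any remaining special tokens (e.g. a trailing <|eot_id|>)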
+
     
-    # Launch the Gradio interface
-    interface.launch()
+    # Ensure the prompt is not repeated in the output
+    if cleaned_output.startswith(user_prompt):
+        cleaned_output = cleaned_output[len(user_prompt):].strip()
+        
+    # Append the new conversation to the history
+    history.append((user_prompt, cleaned_output))
 
+    return history
 
-if __name__ == "__main__":
-    # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="Run LLama-3.2 Vision-Instruct with HF token passed via arguments.")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token to access the model")
-    args = parser.parse_args()
+# Function to clear the chat history
+def clear_chat():
+    return []
 
-    # Pass the HF token to Gradio interface
-    create_gradio_interface(hf_token=args.hf_token)
+# Gradio Interface
+def gradio_interface():
+    with gr.Blocks(visual_theme) as demo:
+        gr.HTML(
+            """
+            <h1 style='text-align: center'>
+            Clean-UI
+            </h1>
+            """
+        )
+        with gr.Row():
+            # Left column with image and parameter inputs
+            with gr.Column(scale=1):
+                image_input = gr.Image(
+                    label="Image", 
+                    type="pil", 
+                    image_mode="RGB", 
+                    height=512,  # Set the height
+                    width=512   # Set the width
+                )
+
+                # Parameter sliders
+                temperature = gr.Slider(
+                    label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1, interactive=True)
+                top_k = gr.Slider(
+                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
+                top_p = gr.Slider(
+                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
+                max_tokens = gr.Slider(
+                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
+
+            # Right column with the chat interface
+            with gr.Column(scale=2):
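+                # Chat history is a list of (user_message, bot_message) tuples, which gr.Chatbot renders as paired bubbles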
+                chat_history = gr.Chatbot(label="Chat", height=512)
+
+                # User input box for prompt
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    container=False,
+                    placeholder="Enter your prompt", 
+                    lines=2
+                )
+
+                # Generate and Clear buttons
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+                # Define the action for the generate button
+                generate_button.click(
+                    fn=describe_image, 
+                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
+                    outputs=[chat_history]
+                )
+
+                # Define the action for the clear button
+                clear_button.click(
+                    fn=clear_chat,
+                    inputs=[],
+                    outputs=[chat_history]
+                )
+
+    return demo
+
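+# Note: unlike the previous version, no --hf_token argument is passed; downloading the gated model assumes you are
+# already authenticated with Hugging Face (e.g. via `huggingface-cli login` or the HF_TOKEN environment variable)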
+# Launch the interface
+demo = gradio_interface()
+demo.launch()