```python
import argparse
import os

import gradio as gr
import torch
from accelerate import Accelerator
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

# Parse the command-line arguments
parser = argparse.ArgumentParser(description="Run Gradio app with a Hugging Face model")
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
args = parser.parse_args()

# Hugging Face token
hf_token = args.hf_token

# Initialize Accelerator
accelerate = Accelerator()
device = accelerate.device

# Set memory management for PyTorch
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # adjust size as needed

# Model ID
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Load the model with the Hugging Face token
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map=device,
    token=hf_token  # Pass the Hugging Face token here
)

# Load the processor
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

# Visual theme
visual_theme = gr.themes.Default()  # Default, Soft or Monochrome

# Constants
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)

# Function to process the image and generate a description
def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
    # Initialize cleaned_output variable
    cleaned_output = ""

    if image is not None:
        # Resize the image if necessary
        image = image.resize(MAX_IMAGE_SIZE)
        prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the image and prompt together
        inputs = processor(image, prompt, return_tensors="pt").to(device)
    else:
        # Text-only input if no image is provided
        prompt = f"<|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the prompt only (no image)
        inputs = processor(text=prompt, return_tensors="pt").to(device)

    # Generate output with the model
    output = model.generate(
        **inputs,
        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
        do_sample=True,  # enable sampling so temperature/top-k/top-p take effect
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    # Decode the raw output
    raw_output = processor.decode(output[0])

    # Clean up the output to remove prompt scaffolding tokens
    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")

    # Ensure the prompt is not repeated in the output
    if cleaned_output.startswith(user_prompt):
        cleaned_output = cleaned_output[len(user_prompt):].strip()

    # Append the new exchange to the history
    history.append((user_prompt, cleaned_output))
    return history

# Function to clear the chat history
def clear_chat():
    return []

# Gradio interface
def gradio_interface():
    with gr.Blocks(visual_theme) as demo:
        gr.HTML(
            """
            <h1 style='text-align: center'>
            meta-llama/Llama-3.2-11B-Vision-Instruct
            </h1>
            """)
        with gr.Row():
            # Left column with image and parameter inputs
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Image",
                    type="pil",
                    image_mode="RGB",
                    height=512,  # Set the height
                    width=512    # Set the width
                )
                # Parameter sliders
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
                max_tokens = gr.Slider(
                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)

            # Right column with the chat interface
            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)

                # User input box for the prompt
                user_prompt = gr.Textbox(
                    show_label=False,
                    container=False,
                    placeholder="Enter your prompt",
                    lines=2
                )

                # Generate and Clear buttons
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

                # Define the action for the Generate button
                generate_button.click(
                    fn=describe_image,
                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
                    outputs=[chat_history]
                )

                # Define the action for the Clear button
                clear_button.click(
                    fn=clear_chat,
                    inputs=[],
                    outputs=[chat_history]
                )
    return demo

# Launch the interface
demo = gradio_interface()
# demo.launch(server_name="0.0.0.0", server_port=12003)
demo.launch()
```
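To try the app, save the script under any name (the filename here is just an example) and pass a Hugging Face token that has access to the gated Llama 3.2 weights, e.g. `python multimodal_gradio_ui.py --hf_token <YOUR_HF_TOKEN>`. By default Gradio serves the UI locally; swapping in the commented-out `demo.launch(server_name="0.0.0.0", server_port=12003)` call instead binds it to port 12003 on all interfaces so other machines can reach it.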