@@ -1,142 +1,137 @@
-import os
-from PIL import Image as PIL_Image
+import gradio as gr
 import torch
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
+import os
+from PIL import Image
 from accelerate import Accelerator
-import gradio as gr
-import gc  # Import garbage collector
-import argparse  # Import argparse for command-line arguments
+from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM, AutoProcessor, GenerationConfig
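+
+# Select the execution device via Accelerate (GPU when available)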
+accelerate = Accelerator()
+device = accelerate.device
+
+# Set memory management for PyTorch
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # or adjust size as needed
+
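+# Load the Llama 3.2 11B Vision-Instruct model in bfloat16 on the selected device, plus its processor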
+model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+model = MllamaForConditionalGeneration.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+)
+processor = AutoProcessor.from_pretrained(model_id)
+
+# Visual theme
+visual_theme = gr.themes.Default()  # Default, Soft or Monochrome
+
 # Constants
-DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-class LlamaInference:
-    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
-        """
-        Initialize the inference class. Load model and processor.
-        """
-        self.accelerator = Accelerator()
-        self.device = self.accelerator.device
-
-        self.model_name = model_name
-        self.hf_token = hf_token
-
-        if self.hf_token is None:
-            raise ValueError("Error: Hugging Face token not provided.")
+# Function to process the image and generate a description
+def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+    # Resize image if necessary
+    image = image.resize(MAX_IMAGE_SIZE)
 
-        # Load model and processor
-        self.model, self.processor = self.load_model_and_processor()
+    # Initialize cleaned_output variable
+    cleaned_output = ""
 
-    def load_model_and_processor(self):
-        """
-        Load the model and processor based on the model name.
-        """
-        model = MllamaForConditionalGeneration.from_pretrained(self.model_name,
-            torch_dtype=torch.bfloat16,
-            use_safetensors=True,
-            device_map=self.device,
-            token=self.hf_token)
-        processor = MllamaProcessor.from_pretrained(self.model_name,
-            token=self.hf_token,
-            use_safetensors=True)
-
-        # Prepare model and processor with accelerator
-        model, processor = self.accelerator.prepare(model, processor)
-        return model, processor
-
-    def process_image(self, image) -> PIL_Image.Image:
-        """
-        Open and convert an uploaded image to RGB format.
-        """
-        return image.convert("RGB")
 
-    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
-        """
-        Generate text from an image using the model and processor.
-        """
-        conversation = [
-            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
-        ]
-        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)
-
-        # Perform inference without computing gradients to save memory
-        with torch.no_grad():
-            output = self.model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-
-        return self.processor.decode(output[0])[len(prompt):]
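+    # Build the vision prompt: the image and begin-of-text tokens, the user's question, and an "Answer:" cue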
+    prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
+    # Preprocess the image and prompt
+    inputs = processor(image, prompt, return_tensors="pt").to(device)
 
-    def cleanup(self):
-        """
-        Clean up instance variables to release memory.
-        """
-        # Move model and processor to CPU before deleting to free up GPU memory
-        self.model.to('cpu')
-        del self.model
-        del self.processor
-        torch.cuda.empty_cache()  # Release GPU memory
-        gc.collect()  # Force garbage collection
-
-        # Clear other instance variables
-        del self.accelerator
-        del self.device
-        del self.hf_token
-
-        print("Cleanup complete. Instance variables deleted and memory cleared.")
-
-
-def inference(image, prompt_text, temperature, top_p, hf_token):
-    """
-    Main inference function to handle Gradio inputs and manage memory cleanup.
-    """
-    # Initialize the inference instance (this loads the model)
-    llama_inference = LlamaInference(hf_token=hf_token)
-
-    try:
-        # Process the image and generate text
-        processed_image = llama_inference.process_image(image)
-        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
-    finally:
-        # Perform memory cleanup
-        llama_inference.cleanup()
-
-    return result
-
-
-# Gradio UI
-def create_gradio_interface(hf_token):
-    """
-    Create the Gradio interface for image-to-text generation.
-    """
-    # Define the input components
-    image_input = gr.Image(type="pil", label="Upload Image")
-    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
-    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
-    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")
-
-    # Define the output component
-    output_text = gr.Textbox(label="Generated Text")
-
-    # Create the interface
-    interface = gr.Interface(
-        fn=lambda image, prompt_text, temperature, top_p: inference(image, prompt_text, temperature, top_p, hf_token),
-        inputs=[image_input, prompt_input, temperature_input, top_p_input],
-        outputs=output_text,
-        title="LLama-3.2 Vision-Instruct",
-        description="Generate descriptive text from an image using the LLama model.",
-        theme="default",
+    # Generate output with model
+    output = model.generate(
+        **inputs,
+        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p
     )
+
+    # Decode the raw output
+    raw_output = processor.decode(output[0])
+
+    # Clean up the output to remove system tokens
+    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
+
 
-    # Launch the Gradio interface
-    interface.launch()
+    # Ensure the prompt is not repeated in the output
+    if cleaned_output.startswith(user_prompt):
+        cleaned_output = cleaned_output[len(user_prompt):].strip()
+
+    # Append the new conversation to the history
+    history.append((user_prompt, cleaned_output))
 
+    return history
 
-if __name__ == "__main__":
-    # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="Run LLama-3.2 Vision-Instruct with HF token passed via arguments.")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token to access the model")
-    args = parser.parse_args()
+# Function to clear the chat history
+def clear_chat():
+    return []
 
-    # Pass the HF token to Gradio interface
-    create_gradio_interface(hf_token=args.hf_token)
+# Gradio Interface
+def gradio_interface():
+    with gr.Blocks(visual_theme) as demo:
+        gr.HTML(
+            """
+            <h1 style='text-align: center'>
+            Clean-UI
+            </h1>
+            """)
+        with gr.Row():
+            # Left column with image and parameter inputs
+            with gr.Column(scale=1):
+                image_input = gr.Image(
+                    label="Image",
+                    type="pil",
+                    image_mode="RGB",
+                    height=512,  # Set the height
+                    width=512  # Set the width
+                )
+
+                # Parameter sliders
+                temperature = gr.Slider(
+                    label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1, interactive=True)
+                top_k = gr.Slider(
+                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
+                top_p = gr.Slider(
+                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
+                max_tokens = gr.Slider(
+                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
+
+            # Right column with the chat interface
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+
+                # User input box for prompt
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    container=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+
+                # Generate and Clear buttons
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+                # Define the action for the generate button
+                generate_button.click(
+                    fn=describe_image,
+                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
+                    outputs=[chat_history]
+                )
+
+                # Define the action for the clear button
+                clear_button.click(
+                    fn=clear_chat,
+                    inputs=[],
+                    outputs=[chat_history]
+                )
+
+    return demo
+
+# Launch the interface
+demo = gradio_interface()
+demo.launch()