@@ -1,108 +1,211 @@
 import argparse
 import os
 import sys
-
 import torch
 from accelerate import Accelerator
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-
+from peft import PeftModel
+import gradio as gr
+from huggingface_hub import HfFolder
+# Initialize accelerator
 accelerator = Accelerator()
-
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
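+# Note: MAX_IMAGE_SIZE is not enforced below; the processor handles image resizing.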
+
+
+def get_hf_token():
+    """Retrieve the Hugging Face token from the environment or the local cache."""
+    # Check if a token is explicitly set in the environment
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        return token
+
+    # Fall back to the token cached by `huggingface-cli login`
+    token = HfFolder.get_token()
+    if token:
+        return token
 
+    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
+    sys.exit(1)
 
-def load_model_and_processor(model_name: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
+
+def load_model_and_processor(model_name: str, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
+    print(f"Loading model: {model_name}")
+    hf_token = get_hf_token()
     model = MllamaForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
+        token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
-
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
+
+    if finetuning_path and os.path.exists(finetuning_path):
+        print(f"Loading LoRA adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            is_trainable=False,
+            torch_dtype=torch.bfloat16
+        )
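+        # The adapter is kept separate from the base weights; model.merge_and_unload() would fold it in permanently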
+ print("LoRA adapter merged successfully")
|
|
|
+
|
|
|
     model, processor = accelerator.prepare(model, processor)
     return model, processor
 
+def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
+    """Process and validate image input"""
+    if image is not None:
+        return image.convert("RGB")
+    if image_path and os.path.exists(image_path):
+        return PIL_Image.open(image_path).convert("RGB")
+    raise ValueError("No valid image provided")
 
-def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
-    """
-    if not os.path.exists(image_path):
-        print(f"The image file '{image_path}' does not exist.")
-        sys.exit(1)
-    with open(image_path, "rb") as f:
-        return PIL_Image.open(f).convert("RGB")
-
-
-def generate_text_from_image(
-    model, processor, image, prompt_text: str, temperature: float, top_p: float
-):
-    """
-    Generate text from an image using the model and processor.
-    """
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float, top_k: int = 50, max_tokens: int = MAX_OUTPUT_TOKENS):
+    """Generate text from an image using the model and processor."""
     conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
-        }
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
-    prompt = processor.apply_chat_template(
-        conversation, add_generation_prompt=True, tokenize=False
-    )
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(
-        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
-    )
-    return processor.decode(output[0])[len(prompt) :]
-
-
-def main(
-    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
-):
-    """
-    Call all the functions.
-    """
-    model, processor = load_model_and_processor(model_name)
-    image = process_image(image_path)
-    result = generate_text_from_image(
-        model, processor, image, prompt_text, temperature, top_p
-    )
-    print("Generated Text: " + result)
-
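+    # do_sample=True makes temperature/top_k/top_p take effect; greedy decoding would ignore them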
+    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, max_new_tokens=max_tokens)
+    return processor.decode(output[0])[len(prompt):]
+
+def gradio_interface(model_name: str):
+    """Create Gradio UI with LoRA support"""
+    # Initialize model state
+    current_model = {"model": None, "processor": None}
+
+    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
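+        # Note: this dict is shared across Gradio sessions; reloading swaps the model for all users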
+        current_model["model"], current_model["processor"] = load_model_and_processor(
+            model_name,
+            lora_path if enable_lora else None
+        )
+        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")
+
+    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+        if image is not None:
+            try:
+                processed_image = process_image(image=image)
+                result = generate_text_from_image(
+                    current_model["model"],
+                    current_model["processor"],
+                    processed_image,
+                    user_prompt,
+                    temperature,
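+                    # Sliders may return floats, so cast the integer-valued controls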
+                    top_p, int(top_k), int(max_tokens)
+                )
+                history.append((user_prompt, result))
+            except Exception as e:
+                history.append((user_prompt, f"Error: {str(e)}"))
+        return history
+
+    def clear_chat():
+        return []
+
+    with gr.Blocks() as demo:
+        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                # Model loading controls
+                with gr.Group():
+                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
+                    lora_path = gr.Textbox(
+                        label="LoRA Weights Path",
+                        placeholder="Path to LoRA weights folder",
+                        visible=False
+                    )
+                    load_status = gr.Textbox(label="Load Status", interactive=False)
+                    load_button = gr.Button("Load/Reload Model")
+
+                # Image and parameter controls
+                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+        # Event handlers
+        enable_lora.change(
+            fn=lambda x: gr.update(visible=x),
+            inputs=[enable_lora],
+            outputs=[lora_path]
+        )
+
+        load_button.click(
+            fn=load_or_reload_model,
+            inputs=[enable_lora, lora_path],
+            outputs=[load_status]
+        )
+
+        generate_button.click(
+            fn=describe_image,
+            inputs=[
+                image_input, user_prompt, temperature,
+                top_k, top_p, max_tokens, chat_history
+            ],
+            outputs=[chat_history]
+        )
+
+        clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+    # Initial model load
+    load_or_reload_model(False)
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    if args.gradio_ui:
+        demo = gradio_interface(args.model_name)
+        demo.launch()
+    else:
+        model, processor = load_model_and_processor(
+            args.model_name,
+            args.finetuning_path
+        )
+        image = process_image(image_path=args.image_path)
+        result = generate_text_from_image(
+            model, processor, image,
+            args.prompt_text,
+            args.temperature,
+            args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Generate text from an image and prompt using the 3.2 MM Llama model."
-    )
-    parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument(
-        "--prompt_text", type=str, help="Prompt text to describe the image"
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.7,
-        help="Temperature for generation (default: 0.7)",
-    )
-    parser.add_argument(
-        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default=DEFAULT_MODEL,
-        help=f"Model name (default: '{DEFAULT_MODEL}')",
-    )
-
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
+
     args = parser.parse_args()
-    main(
-        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
-    )
+    main(args)
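
Usage sketch (assumes the patched script is saved as multi_modal_infer.py; image and adapter paths are placeholders):

    # One-shot CLI inference
    python multi_modal_infer.py --image_path ./sample.jpg --prompt_text "Describe this image" --temperature 0.5

    # Gradio UI, optionally with a LoRA adapter
    python multi_modal_infer.py --gradio_ui --finetuning_path ./lora_weights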