import argparse
import os
import sys

import gradio as gr
import torch
from accelerate import Accelerator
from huggingface_hub import HfFolder
from peft import PeftModel
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)


def get_hf_token():
    """Retrieve the Hugging Face token from the environment or cache."""
    # Check if a token is explicitly set in the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        return token

    # Otherwise retrieve the token from the Hugging Face cache (set via `huggingface-cli login`)
    token = HfFolder.get_token()
    if token:
        return token

    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
    sys.exit(1)


def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Load the model and processor, with an optional LoRA adapter."""
    print(f"Loading model: {model_name}")
    hf_token = get_hf_token()
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(
        model_name, token=hf_token, use_safetensors=True
    )

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model, finetuning_path, torch_dtype=torch.bfloat16
        )
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor


def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input."""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")


def generate_text_from_image(
    model, processor, image, prompt_text: str, temperature: float, top_p: float
):
    """Generate text from an image using the model."""
    conversation = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    inputs = processor(
        image, prompt, text_kwargs={"add_special_tokens": False}, return_tensors="pt"
    ).to(device)
    print("Input Prompt:\n", processor.tokenizer.decode(inputs.input_ids[0]))
    output = model.generate(
        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS
    )
    # Strip the prompt from the decoded output so only the generated text remains
    return processor.decode(output[0])[len(prompt) :]


def gradio_interface(model_name: str):
    """Create the Gradio UI with LoRA support."""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name, lora_path if enable_lora else None
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(
        image, user_prompt, temperature, top_k, top_p, max_tokens, history
    ):
        # Note: top_k and max_tokens are collected from the UI but are not
        # currently forwarded to generate_text_from_image.
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style='text-align: center'>
            Llama Vision Model Interface
            </h1>
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(
                    label="Image", type="pil", image_mode="RGB", height=512, width=512
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1
                )
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1
                )
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1
                )
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=50,
                    maximum=MAX_OUTPUT_TOKENS,
                    value=100,
                    step=50,
                )

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False, placeholder="Enter your prompt", lines=2
                )

                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x), inputs=[enable_lora], outputs=[lora_path]
        )
        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )
        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input,
                user_prompt,
                temperature,
                top_k,
                top_p,
                max_tokens,
                chat_history,
            ],
            outputs=[chat_history],
        )
        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)
    return demo


def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name, args.finetuning_path
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image, args.prompt_text, args.temperature, args.top_p
        )
        print("Generated Text:", result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Multi-modal inference with optional Gradio UI and LoRA support"
    )
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument(
        "--model_name", type=str, default=DEFAULT_MODEL, help="Model name"
    )
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)