
Fixed the Gradio UI while running the tests; it is working as of this commit

Himanshu Shukla 5 months ago
parent
commit
6b6bb37ecd

+ 158 - 84
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -1,117 +1,191 @@
+import argparse
 import os
 import sys
-import argparse
-from PIL import Image as PIL_Image
 import torch
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
-from peft import PeftModel  # Make sure to install the `peft` library
+from PIL import Image as PIL_Image
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from peft import PeftModel
+import gradio as gr
 
+# Initialize accelerator
 accelerator = Accelerator()
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-
-def load_model_and_processor(model_name: str, hf_token: str, finetuning_path: str = None):
-    """
-    Load the model and processor, and optionally load adapter weights if specified
-    """
-    # Load pre-trained model and processor
+def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
+    print(f"Loading model: {model_name}")
     model = MllamaForConditionalGeneration.from_pretrained(
-        model_name, 
-        torch_dtype=torch.bfloat16, 
-        use_safetensors=True, 
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
         device_map=device,
         token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(
-        model_name, 
-        token=hf_token, 
-        use_safetensors=True
-    )
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
 
-    # If a finetuning path is provided, load the adapter model
     if finetuning_path and os.path.exists(finetuning_path):
-        adapter_weights_path = os.path.join(finetuning_path, "adapter_model.safetensors")
-        adapter_config_path = os.path.join(finetuning_path, "adapter_config.json")
-
-        if os.path.exists(adapter_weights_path) and os.path.exists(adapter_config_path):
-            print(f"Loading adapter from '{finetuning_path}'...")
-            # Load the model with adapters using `peft`
-            model = PeftModel.from_pretrained(
-                model,
-                finetuning_path,  # This should be the folder containing the adapter files
-                is_adapter=True,
-                torch_dtype=torch.bfloat16
-            )
-
-            print("Adapter merged successfully with the pre-trained model.")
-        else:
-            print(f"Adapter files not found in '{finetuning_path}'. Using pre-trained model only.")
-    else:
-        print(f"No fine-tuned weights or adapters found in '{finetuning_path}'. Using pre-trained model only.")
-
-    # Prepare the model and processor for accelerated training
-    model, processor = accelerator.prepare(model, processor)
+        print(f"Loading LoRA adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            is_adapter=True,
+            torch_dtype=torch.bfloat16
+        )
+        print("LoRA adapter merged successfully")
     
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
-
-def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
-    """
-    if not os.path.exists(image_path):
-        print(f"The image file '{image_path}' does not exist.")
-        sys.exit(1)
-    with open(image_path, "rb") as f:
-        return PIL_Image.open(f).convert("RGB")
-
+def process_image(image_path: str = None, image = None) -> PIL_Image.Image:
+    """Process and validate image input"""
+    if image is not None:
+        return image.convert("RGB")
+    if image_path and os.path.exists(image_path):
+        return PIL_Image.open(image_path).convert("RGB")
+    raise ValueError("No valid image provided")
 
 def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
-    """
-    Generate text from an image using the model and processor.
-    """
+    """Generate text from image using model"""
     conversation = [
         {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
     prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=2048)
+    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
     return processor.decode(output[0])[len(prompt):]
 
-
-def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str, finetuning_path: str = None):
-    """
-    Call all the functions and optionally merge adapter weights from a specified path.
-    """
-    model, processor = load_model_and_processor(model_name, hf_token, finetuning_path)
-    image = process_image(image_path)
-    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
-    print("Generated Text: " + result)
-
+def gradio_interface(model_name: str, hf_token: str):
+    """Create Gradio UI with LoRA support"""
+    # Initialize model state
+    current_model = {"model": None, "processor": None}
+    
+    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
+        current_model["model"], current_model["processor"] = load_model_and_processor(
+            model_name, 
+            hf_token,
+            lora_path if enable_lora else None
+        )
+        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")
+
+    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+        if image is not None:
+            try:
+                processed_image = process_image(image=image)
+                result = generate_text_from_image(
+                    current_model["model"],
+                    current_model["processor"],
+                    processed_image,
+                    user_prompt,
+                    temperature,
+                    top_p
+                )
+                history.append((user_prompt, result))
+            except Exception as e:
+                history.append((user_prompt, f"Error: {str(e)}"))
+        return history
+
+    def clear_chat():
+        return []
+
+    with gr.Blocks() as demo:
+        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
+        
+        with gr.Row():
+            with gr.Column(scale=1):
+                # Model loading controls
+                with gr.Group():
+                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
+                    lora_path = gr.Textbox(
+                        label="LoRA Weights Path",
+                        placeholder="Path to LoRA weights folder",
+                        visible=False
+                    )
+                    load_status = gr.Textbox(label="Load Status", interactive=False)
+                    load_button = gr.Button("Load/Reload Model")
+
+                # Image and parameter controls
+                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+                
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+        # Event handlers
+        enable_lora.change(
+            fn=lambda x: gr.update(visible=x),
+            inputs=[enable_lora],
+            outputs=[lora_path]
+        )
+        
+        load_button.click(
+            fn=load_or_reload_model,
+            inputs=[enable_lora, lora_path],
+            outputs=[load_status]
+        )
+
+        generate_button.click(
+            fn=describe_image,
+            inputs=[
+                image_input, user_prompt, temperature,
+                top_k, top_p, max_tokens, chat_history
+            ],
+            outputs=[chat_history]
+        )
+        
+        clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+    # Initial model load
+    load_or_reload_model(False)
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    if args.gradio_ui:
+        demo = gradio_interface(args.model_name, args.hf_token)
+        demo.launch()
+    else:
+        model, processor = load_model_and_processor(
+            args.model_name,
+            args.hf_token,
+            args.finetuning_path
+        )
+        image = process_image(image_path=args.image_path)
+        result = generate_text_from_image(
+            model, processor, image,
+            args.prompt_text,
+            args.temperature,
+            args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    # Example usage with argparse (optional)
-    parser = argparse.ArgumentParser(description="Generate text from an image using a fine-tuned model with adapters.")
-    parser.add_argument("--image_path", type=str, required=True, help="Path to the input image.")
-    parser.add_argument("--prompt_text", type=str, required=True, help="Prompt text for the image.")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
-    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling.")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Pre-trained model name.")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face API token.")
-    parser.add_argument("--finetuning_path", type=str, help="Path to the fine-tuning weights (adapters).")
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
     
     args = parser.parse_args()
-
-    main(
-        image_path=args.image_path,
-        prompt_text=args.prompt_text,
-        temperature=args.temperature,
-        top_p=args.top_p,
-        model_name=args.model_name,
-        hf_token=args.hf_token,
-        finetuning_path=args.finetuning_path
-    )
+    main(args)
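
For reference, a quick way to exercise this version of the script in both modes (a sketch based on the argparse flags in the diff above; the token value, image path, and prompt are placeholders):

# Launch the Gradio UI
python recipes/quickstart/inference/local_inference/multi_modal_infer.py --gradio_ui --hf_token <your_hf_token>

# Single-image CLI inference; --finetuning_path can optionally point at a LoRA adapter folder
python recipes/quickstart/inference/local_inference/multi_modal_infer.py \
    --image_path ./example.jpg \
    --prompt_text "Describe this image" \
    --hf_token <your_hf_token>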