
Add single-file support for 1. terminal inference, 2. Gradio inference, 3. checkpoint (LoRA adapter) inference

Himanshu Shukla, 5 months ago
commit 6b1c0d582b
1 file changed, 96 insertions(+), 73 deletions(-)
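Note: besides torch, transformers, and accelerate, the rewritten script now imports peft (for the LoRA adapter path) and gradio (for the web UI). Assuming a standard pip environment, the two new dependencies can be installed with:

    pip install peft gradio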

recipes/quickstart/inference/local_inference/multi_modal_infer.py (+96, -73)

@@ -1,108 +1,131 @@
 import argparse
 import os
 import sys
-
 import torch
 from accelerate import Accelerator
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from peft import PeftModel
+import gradio as gr
 
+# Initialize accelerator
 accelerator = Accelerator()
-
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-
-def load_model_and_processor(model_name: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
+def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
     model = MllamaForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
+        token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
-
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
+
+    if finetuning_path and os.path.exists(finetuning_path):
+        print(f"Loading adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            is_trainable=False,
+            torch_dtype=torch.bfloat16
+        )
+        print("Adapter merged successfully")
+    
     model, processor = accelerator.prepare(model, processor)
     return model, processor
 
-
 def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
-    """
+    """Process and validate image input"""
     if not os.path.exists(image_path):
-        print(f"The image file '{image_path}' does not exist.")
+        print(f"Image file '{image_path}' does not exist.")
         sys.exit(1)
-    with open(image_path, "rb") as f:
-        return PIL_Image.open(f).convert("RGB")
+    return PIL_Image.open(image_path).convert("RGB")
 
-
-def generate_text_from_image(
-    model, processor, image, prompt_text: str, temperature: float, top_p: float
-):
-    """
-    Generate text from an image using the model and processor.
-    """
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float, top_k: int = 50, max_tokens: int = MAX_OUTPUT_TOKENS):
+    """Generate text from image using model"""
     conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
-        }
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
-    prompt = processor.apply_chat_template(
-        conversation, add_generation_prompt=True, tokenize=False
-    )
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(
-        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
+    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_k=top_k, top_p=top_p, max_new_tokens=max_tokens)
+    return processor.decode(output[0])[len(prompt):]
+
+def gradio_interface(model, processor):
+    """Create Gradio UI"""
+    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+        if image is not None:
+            image = image.resize(MAX_IMAGE_SIZE)
+            result = generate_text_from_image(model, processor, image, user_prompt, temperature, top_p, int(top_k), int(max_tokens))
+            history.append((user_prompt, result))
+        return history
+
+    def clear_chat():
+        return []
+
+    with gr.Blocks() as demo:
+        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
+        
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", lines=2)
+                
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+                generate_button.click(
+                    fn=describe_image,
+                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
+                    outputs=[chat_history]
+                )
+                clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    model, processor = load_model_and_processor(
+        args.model_name,
+        args.hf_token,
+        args.finetuning_path
     )
-    return processor.decode(output[0])[len(prompt) :]
-
-
-def main(
-    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
-):
-    """
-    Call all the functions.
-    """
-    model, processor = load_model_and_processor(model_name)
-    image = process_image(image_path)
-    result = generate_text_from_image(
-        model, processor, image, prompt_text, temperature, top_p
-    )
-    print("Generated Text: " + result)
 
+    if args.gradio_ui:
+        demo = gradio_interface(model, processor)
+        demo.launch()
+    else:
+        image = process_image(args.image_path)
+        result = generate_text_from_image(
+            model, processor, image, args.prompt_text, args.temperature, args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Generate text from an image and prompt using the 3.2 MM Llama model."
-    )
-    parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument(
-        "--prompt_text", type=str, help="Prompt text to describe the image"
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.7,
-        help="Temperature for generation (default: 0.7)",
-    )
-    parser.add_argument(
-        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default=DEFAULT_MODEL,
-        help=f"Model name (default: '{DEFAULT_MODEL}')",
-    )
-
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
+    
     args = parser.parse_args()
-    main(
-        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
-    )
+    main(args)
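
For reference, example invocations of the three modes this commit adds, assembled from the argparse flags above; the image path, prompt, adapter path, and token values are placeholders:

    # 1. Terminal inference
    python multi_modal_infer.py --image_path ./sample.jpg --prompt_text "Describe this image." --hf_token YOUR_HF_TOKEN

    # 2. Gradio UI (no image/prompt flags needed; they are supplied in the browser)
    python multi_modal_infer.py --gradio_ui --hf_token YOUR_HF_TOKEN

    # 3. Checkpoint inference with a fine-tuned LoRA adapter
    python multi_modal_infer.py --image_path ./sample.jpg --prompt_text "Describe this image." --finetuning_path ./my-lora-adapter --hf_token YOUR_HF_TOKEN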
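
One caveat on the adapter path: PeftModel.from_pretrained only attaches the LoRA weights to the base model, it does not merge them (hence "Adapter loaded successfully" rather than "merged" above). If merged weights are wanted, for example to drop the adapter indirection at inference time, PEFT provides merge_and_unload(); a minimal sketch, assuming model is the PeftModel returned by load_model_and_processor:

    # Fold the LoRA deltas into the base weights and strip the PEFT wrappers
    model = model.merge_and_unload()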