
added single-file support for 1. terminal inference, 2. Gradio inference, 3. checkpoint inference
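All three modes live in this one script: terminal inference takes --image_path and --prompt_text, the Gradio UI is launched with --gradio_ui, and checkpoint inference from a fine-tuned adapter adds --finetuning_path (plus --hf_token for gated models).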

Himanshu Shukla 11 months ago
parent
commit
6b1c0d582b
1 changed file with 96 additions and 73 deletions
      recipes/quickstart/inference/local_inference/multi_modal_infer.py

+ 96 - 73
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -1,108 +1,131 @@
 import argparse
 import os
 import sys
-
 import torch
 from accelerate import Accelerator
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from peft import PeftModel
+import gradio as gr
 
+# Initialize accelerator
 accelerator = Accelerator()
-
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-
-def load_model_and_processor(model_name: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
+def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
     model = MllamaForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
+        token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
-
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
+
+    if finetuning_path and os.path.exists(finetuning_path):
+        print(f"Loading adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            torch_dtype=torch.bfloat16
+        )
+        print("Adapter merged successfully")
+    
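+    # accelerator.prepare moves the model to the selected device; the processor is returned unchanged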
     model, processor = accelerator.prepare(model, processor)
     return model, processor
 
-
 def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
-    """
+    """Process and validate image input"""
     if not os.path.exists(image_path):
-        print(f"The image file '{image_path}' does not exist.")
+        print(f"Image file '{image_path}' does not exist.")
         sys.exit(1)
-    with open(image_path, "rb") as f:
-        return PIL_Image.open(f).convert("RGB")
+    return PIL_Image.open(image_path).convert("RGB")
 
-
-def generate_text_from_image(
-    model, processor, image, prompt_text: str, temperature: float, top_p: float
-):
-    """
-    Generate text from an image using the model and processor.
-    """
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float, top_k: int = 50, max_tokens: int = MAX_OUTPUT_TOKENS):
+    """Generate text from image using model"""
     conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
-        }
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
-    prompt = processor.apply_chat_template(
-        conversation, add_generation_prompt=True, tokenize=False
-    )
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(
-        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
+    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k, max_new_tokens=max_tokens)
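+    # Slice off the echoed prompt so only the newly generated text is returned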
+    return processor.decode(output[0])[len(prompt):]
+
+def gradio_interface(model, processor):
+    """Create Gradio UI"""
+    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+        if image is not None:
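+            # Force a fixed 1120x1120 input (may distort aspect ratio) to bound memory use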
+            image = image.resize(MAX_IMAGE_SIZE)
+            result = generate_text_from_image(model, processor, image, user_prompt, temperature, top_p, top_k, max_tokens)
+            history.append((user_prompt, result))
+        return history
+
+    def clear_chat():
+        return []
+
+    with gr.Blocks() as demo:
+        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
+        
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", lines=2)
+                
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+                generate_button.click(
+                    fn=describe_image,
+                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
+                    outputs=[chat_history]
+                )
+                clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    model, processor = load_model_and_processor(
+        args.model_name,
+        args.hf_token,
+        args.finetuning_path
     )
-    return processor.decode(output[0])[len(prompt) :]
-
-
-def main(
-    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
-):
-    """
-    Call all the functions.
-    """
-    model, processor = load_model_and_processor(model_name)
-    image = process_image(image_path)
-    result = generate_text_from_image(
-        model, processor, image, prompt_text, temperature, top_p
-    )
-    print("Generated Text: " + result)
 
+    if args.gradio_ui:
+        demo = gradio_interface(model, processor)
+        demo.launch()
+    else:
+        image = process_image(args.image_path)
+        result = generate_text_from_image(
+            model, processor, image, args.prompt_text, args.temperature, args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Generate text from an image and prompt using the 3.2 MM Llama model."
-    )
-    parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument(
-        "--prompt_text", type=str, help="Prompt text to describe the image"
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.7,
-        help="Temperature for generation (default: 0.7)",
-    )
-    parser.add_argument(
-        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default=DEFAULT_MODEL,
-        help=f"Model name (default: '{DEFAULT_MODEL}')",
-    )
-
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
+    
     args = parser.parse_args()
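+    # Fail fast: terminal mode needs both an image and a prompt
+    if not args.gradio_ui and not (args.image_path and args.prompt_text):
+        parser.error("--image_path and --prompt_text are required unless --gradio_ui is set")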
-    main(
-        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
-    )
+    main(args)
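
For context, a minimal sketch of driving these functions programmatically, assuming the file is importable as multi_modal_infer and using a placeholder image path:

    # Hypothetical driver; "example.jpg" and the prompt are placeholders
    from multi_modal_infer import load_model_and_processor, process_image, generate_text_from_image

    model, processor = load_model_and_processor("meta-llama/Llama-3.2-11B-Vision-Instruct")
    image = process_image("example.jpg")
    print(generate_text_from_image(model, processor, image, "Describe this image.", temperature=0.7, top_p=0.9))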