
Added basic LlamaInference class structure, model loading functionality, image processing and text generation

himanshushukla12 authored 6 months ago
commit 405371288b
1 file changed, 92 insertions(+), 0 deletions(-)

+ 92 - 0
recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py

@@ -0,0 +1,92 @@
+import os
+from PIL import Image as PIL_Image
+import torch
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from accelerate import Accelerator
+import gradio as gr
+
+# Initialize accelerator
+accelerator = Accelerator()
+
+device = accelerator.device
+
+# Constants
+DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+
+def load_model_and_processor(model_name: str, hf_token: str):
+    """
+    Load the Mllama model and processor; works for both the 11B and 90B Vision-Instruct checkpoints.
+    """
+    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True,
+                                                           device_map=device, token=hf_token)
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
+
+    model = accelerator.prepare(model)  # only the model needs preparing; the processor holds no tensors
+    return model, processor
+
+
+def process_image(image) -> PIL_Image.Image:
+    """
+    Convert an uploaded image to RGB format.
+    """
+    return image.convert("RGB")
+
+
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+    """
+    Generate text from an image using the model and processor.
+    """
+    conversation = [
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+    ]
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    inputs = processor(image, prompt, return_tensors="pt").to(device)
+    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512)  # sampling must be enabled for temperature/top_p to apply
+    return processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)  # slice off the prompt tokens rather than prompt characters
+
+
+def inference(image, prompt_text, temperature, top_p):
+    """
+    Wrapper function to load the model and generate text based on inputs from Gradio UI.
+    """
+    hf_token = os.getenv("HF_TOKEN")  # Get the Hugging Face token from the environment
+    if hf_token is None:
+        return "Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable."
+    
+    model, processor = load_model_and_processor(DEFAULT_MODEL, hf_token)  # NOTE: reloads the checkpoint on every request
+    processed_image = process_image(image)
+    result = generate_text_from_image(model, processor, processed_image, prompt_text, temperature, top_p)
+    return result
+
+
+# Gradio UI
+def create_gradio_interface():
+    """
+    Create the Gradio interface for image-to-text generation.
+    """
+    # Define the input components
+    image_input = gr.Image(type="pil", label="Upload Image")
+    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
+    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
+    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")
+
+    # Define the output component
+    output_text = gr.Textbox(label="Generated Text")
+
+    # Create the interface
+    interface = gr.Interface(
+        fn=inference,
+        inputs=[image_input, prompt_input, temperature_input, top_p_input],
+        outputs=output_text,
+        title="LLama-3.2 Vision-Instruct",
+        description="Generate descriptive text from an image using the LLama model.",
+        theme="default",
+    )
+    
+    # Launch the Gradio interface
+    interface.launch()
+
+
+if __name__ == "__main__":
+    create_gradio_interface()
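
One follow-up worth noting: `inference` calls `load_model_and_processor` on every Gradio request, so the 11B checkpoint is reloaded each time the user submits. A minimal sketch of one way to avoid that, caching the loaded pair behind `functools.lru_cache` (the `get_model_and_processor` name is illustrative, not part of this commit):

from functools import lru_cache

@lru_cache(maxsize=1)
def get_model_and_processor(model_name: str, hf_token: str):
    # Load once and reuse across requests; lru_cache keys on the
    # (model_name, hf_token) pair, so repeat calls hit the cache.
    return load_model_and_processor(model_name, hf_token)

`inference` would then call `get_model_and_processor(DEFAULT_MODEL, hf_token)` in place of the direct load.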
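
Assuming `HF_TOKEN` is exported and access to the gated Llama 3.2 Vision weights has been granted, the wrapper can also be smoke-tested without the UI; the image path here is a placeholder:

from PIL import Image

image = Image.open("example.jpg")  # any local test image
print(inference(image, "Describe this image.", temperature=0.7, top_p=0.9))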