# multi_modal_infer_Gradio_UI.py

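"""
Gradio UI for image-to-text generation with Llama 3.2 Vision-Instruct.

Loads the model via Hugging Face Transformers (the HF_TOKEN environment
variable must hold a token with access to the model weights) and serves a
simple Gradio interface for prompting the model with an uploaded image.
"""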
import os

import gradio as gr
import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
def load_model_and_processor(model_name: str, hf_token: str):
    """
    Load the model and processor for the given 11B or 90B Vision-Instruct checkpoint.
    """
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
    model, processor = accelerator.prepare(model, processor)
    return model, processor
def process_image(image) -> PIL_Image.Image:
    """
    Convert an uploaded image to RGB format.
    """
    return image.convert("RGB")
def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    # do_sample=True is required for temperature/top_p to take effect
    output = model.generate(
        **inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512
    )
    # Decode only the newly generated tokens, skipping the prompt and special tokens
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][prompt_len:], skip_special_tokens=True)
def inference(image, prompt_text, temperature, top_p):
    """
    Wrapper function to load the model and generate text from the Gradio UI inputs.
    """
    hf_token = os.getenv("HF_TOKEN")  # Get the Hugging Face token from the environment
    if hf_token is None:
        return "Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable."
    # Note: the model is reloaded on every request; cache it for repeated use.
    model, processor = load_model_and_processor(DEFAULT_MODEL, hf_token)
    processed_image = process_image(image)
    return generate_text_from_image(model, processor, processed_image, prompt_text, temperature, top_p)
# Gradio UI
def create_gradio_interface():
    """
    Create the Gradio interface for image-to-text generation.
    """
    # Define the input components
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")

    # Define the output component
    output_text = gr.Textbox(label="Generated Text")

    # Create the interface
    interface = gr.Interface(
        fn=inference,
        inputs=[image_input, prompt_input, temperature_input, top_p_input],
        outputs=output_text,
        title="Llama 3.2 Vision-Instruct",
        description="Generate descriptive text from an image using the Llama model.",
        theme="default",
    )

    # Launch the Gradio interface
    interface.launch()
if __name__ == "__main__":
    create_gradio_interface()
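
# Usage sketch (assumes gradio, torch, transformers, and accelerate are installed,
# and that HF_TOKEN grants access to the gated meta-llama checkpoint):
#   export HF_TOKEN=<your Hugging Face access token>
#   python multi_modal_infer_Gradio_UI.py
# Gradio then serves the UI locally (by default at http://127.0.0.1:7860).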