multi_modal_infer_Gradio_UI.py

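"""
Gradio UI for multimodal inference with Llama-3.2-11B-Vision-Instruct:
upload an image, enter a text prompt, and the model generates a text
response about the image. The model is loaded per request and cleaned
up afterwards to release GPU memory.
"""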
import os
from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from accelerate import Accelerator
import gradio as gr
import gc  # Import garbage collector

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"

class LlamaInference:
    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
        """
        Initialize the inference class. Load model and processor.
        """
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        self.model_name = model_name
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        if self.hf_token is None:
            raise ValueError(
                "Error: Hugging Face token not found in environment. "
                "Please set the HF_TOKEN environment variable."
            )

        # Load model and processor
        self.model, self.processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """
        Load the model and processor based on the model name.
        """
        model = MllamaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
            device_map=self.device,
            token=self.hf_token,
        )
        processor = MllamaProcessor.from_pretrained(
            self.model_name,
            token=self.hf_token,
            use_safetensors=True,
        )

        # Prepare model and processor with accelerator
        model, processor = self.accelerator.prepare(model, processor)
        return model, processor

    def process_image(self, image) -> PIL_Image.Image:
        """
        Open and convert an uploaded image to RGB format.
        """
        return image.convert("RGB")

    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
        """
        Generate text from an image using the model and processor.
        """
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)

        # Perform inference without computing gradients to save memory.
        # do_sample=True is needed so the temperature and top_p settings actually take effect.
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=512,
            )
        # Strip the prompt from the decoded output, returning only the newly generated text
        return self.processor.decode(output[0])[len(prompt):]

    def cleanup(self):
        """
        Clean up instance variables to release memory.
        """
        # Move the model to CPU before deleting it so its GPU memory can be reclaimed
        self.model.to("cpu")
        del self.model
        del self.processor
        torch.cuda.empty_cache()  # Release cached GPU memory
        gc.collect()  # Force garbage collection

        # Clear other instance variables
        del self.accelerator
        del self.device
        del self.hf_token
        print("Cleanup complete. Instance variables deleted and memory cleared.")


def inference(image, prompt_text, temperature, top_p):
    """
    Main inference function to handle Gradio inputs and manage memory cleanup.
    """
    # Initialize the inference instance (the model is loaded from scratch on every request)
    llama_inference = LlamaInference()
    try:
        # Process the image and generate text
        processed_image = llama_inference.process_image(image)
        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
    finally:
        # Perform memory cleanup
        llama_inference.cleanup()
    return result


# Gradio UI
def create_gradio_interface():
    """
    Create the Gradio interface for image-to-text generation.
    """
    # Define the input components
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")

    # Define the output component
    output_text = gr.Textbox(label="Generated Text")

    # Create the interface
    interface = gr.Interface(
        fn=inference,
        inputs=[image_input, prompt_input, temperature_input, top_p_input],
        outputs=output_text,
        title="Llama-3.2 Vision-Instruct",
        description="Generate descriptive text from an image using the Llama model.",
        theme="default",
    )

    # Launch the Gradio interface
    interface.launch()


if __name__ == "__main__":
    create_gradio_interface()
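
# Usage (assumes the dependencies imported above are installed and that you have
# access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct checkpoint on Hugging Face):
#   export HF_TOKEN=<your Hugging Face access token>
#   python multi_modal_infer_Gradio_UI.py
# The Gradio app launches locally; upload an image, enter a prompt, and adjust the
# Temperature and Top P sliders to control sampling.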