multi_modal_infer_Gradio_UI.py

import os
from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from accelerate import Accelerator
import gradio as gr
import gc  # Garbage collector, used to free memory after each request
import argparse  # Command-line argument parsing

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


class LlamaInference:
    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
        """
        Initialize the inference class. Load model and processor.
        """
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        self.model_name = model_name
        self.hf_token = hf_token

        if self.hf_token is None:
            raise ValueError("Error: Hugging Face token not provided.")

        # Load model and processor
        self.model, self.processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """
        Load the model and processor based on the model name.
        """
        model = MllamaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
            device_map=self.device,
            token=self.hf_token,
        )
        processor = MllamaProcessor.from_pretrained(
            self.model_name,
            token=self.hf_token,
            use_safetensors=True,
        )

        # Prepare model and processor with accelerator
        model, processor = self.accelerator.prepare(model, processor)
        return model, processor

    def process_image(self, image) -> PIL_Image.Image:
        """
        Convert an uploaded PIL image to RGB format.
        """
        return image.convert("RGB")

    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
        """
        Generate text from an image using the model and processor.
        """
        conversation = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)

        # Perform inference without computing gradients to save memory
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                do_sample=True,  # Sampling must be enabled for temperature/top_p to have an effect
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=512,
            )
        # Decode the full sequence and strip the echoed prompt from the result
        return self.processor.decode(output[0])[len(prompt):]

    def cleanup(self):
        """
        Clean up instance variables to release memory.
        """
        # Move the model to CPU before deleting it to free up GPU memory
        self.model.to("cpu")
        del self.model
        del self.processor
        torch.cuda.empty_cache()  # Release cached GPU memory
        gc.collect()  # Force garbage collection

        # Clear other instance variables
        del self.accelerator
        del self.device
        del self.hf_token
        print("Cleanup complete. Instance variables deleted and memory cleared.")


def inference(image, prompt_text, temperature, top_p, hf_token):
    """
    Main inference function to handle Gradio inputs and manage memory cleanup.
    """
    # Initialize the inference instance (this loads the model)
    llama_inference = LlamaInference(hf_token=hf_token)

    try:
        # Process the image and generate text
        processed_image = llama_inference.process_image(image)
        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
    finally:
        # Perform memory cleanup
        llama_inference.cleanup()

    return result


# Gradio UI
def create_gradio_interface(hf_token):
    """
    Create the Gradio interface for image-to-text generation.
    """
    # Define the input components
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(lines=2, placeholder="Enter your prompt text", label="Prompt")
    temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
    top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P")

    # Define the output component
    output_text = gr.Textbox(label="Generated Text")

    # Create the interface
    interface = gr.Interface(
        fn=lambda image, prompt_text, temperature, top_p: inference(image, prompt_text, temperature, top_p, hf_token),
        inputs=[image_input, prompt_input, temperature_input, top_p_input],
        outputs=output_text,
        title="Llama-3.2 Vision-Instruct",
        description="Generate descriptive text from an image using the Llama model.",
        theme="default",
    )

    # Launch the Gradio interface
    interface.launch()


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run Llama-3.2 Vision-Instruct with an HF token passed via arguments.")
    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token to access the model")
    args = parser.parse_args()

    # Pass the HF token to the Gradio interface
    create_gradio_interface(hf_token=args.hf_token)
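
# Example usage (assuming the Hugging Face account behind the token has been
# granted access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct repo;
# replace the placeholder with a real token):
#
#   python multi_modal_infer_Gradio_UI.py --hf_token <your_hf_token>
#
# Gradio prints a local URL when interface.launch() runs; open it in a browser,
# upload an image, enter a prompt, and adjust the Temperature / Top P sliders.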