@@ -4,59 +4,102 @@ import torch
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
 import gradio as gr
+import gc  # Garbage collector, used to free memory during cleanup
 
-# Initialize accelerator
-accelerator = Accelerator()
-
-device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
-
-def load_model_and_processor(model_name: str, hf_token: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True,
-                                                           device_map=device, token=hf_token)
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
-
-    model, processor = accelerator.prepare(model, processor)
-    return model, processor
-
-
-def process_image(image) -> PIL_Image.Image:
-    """
-    Open and convert an uploaded image to RGB format.
-    """
-    return image.convert("RGB")
+class LlamaInference:
+    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
+        """
+        Initialize the accelerator, resolve the Hugging Face token, and load the model and processor.
+        """
+        self.accelerator = Accelerator()
+        self.device = self.accelerator.device
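+        # Accelerate picks the best available device (CUDA GPU if present, otherwise CPU)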
+
+        self.model_name = model_name
+        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+
+        if self.hf_token is None:
+            raise ValueError("Hugging Face token not provided. Pass hf_token or set the HF_TOKEN environment variable.")
+
+        # Load model and processor
+        self.model, self.processor = self.load_model_and_processor()
+
+    def load_model_and_processor(self):
+        """
+        Load the model and processor based on the model name.
+        """
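+        # bfloat16 weights halve memory use relative to float32; device_map places
+        # the weights on the accelerator-selected device at load time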
+        model = MllamaForConditionalGeneration.from_pretrained(self.model_name,
+                                                               torch_dtype=torch.bfloat16,
+                                                               use_safetensors=True,
+                                                               device_map=self.device,
+                                                               token=self.hf_token)
+        processor = MllamaProcessor.from_pretrained(self.model_name,
+                                                    token=self.hf_token,
+                                                    use_safetensors=True)
+
+        # Prepare model and processor with the accelerator
+        model, processor = self.accelerator.prepare(model, processor)
+        return model, processor
+
+    def process_image(self, image) -> PIL_Image.Image:
+        """
+        Convert an uploaded image to RGB format.
+        """
+        return image.convert("RGB")
+    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
+        """
+        Generate text from an image using the model and processor.
+        """
+        conversation = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        ]
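+        # The chat template renders the conversation into a single prompt string with an
+        # image placeholder; the processor pairs that placeholder with the pixel values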
+        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)
+
+        # Perform inference without computing gradients to save memory
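+        # do_sample=True is required for temperature and top_p to take effect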
+        with torch.no_grad():
+            output = self.model.generate(**inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512)
+
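+        # Note: slicing by len(prompt) assumes the decoded output starts with the rendered
+        # prompt text verbatim, which holds here because special tokens are not skipped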
+        return self.processor.decode(output[0])[len(prompt):]
+
+    def cleanup(self):
+        """
+        Clean up instance variables to release memory.
+        """
+        # Move the model to CPU before deleting it so its GPU memory can be freed
+        self.model.to('cpu')
+        del self.model
+        del self.processor
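+        # empty_cache() only returns cached, unused blocks to the driver, which is
+        # why the model and processor are deleted first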
+        torch.cuda.empty_cache()  # Release cached GPU memory
+        gc.collect()  # Force a garbage-collection pass
+
+        # Clear the remaining instance variables
+        del self.accelerator
+        del self.device
+        del self.hf_token
+        print("Cleanup complete. Instance variables deleted and memory cleared.")
 
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+def inference(image, prompt_text, temperature, top_p):
     """
-    Generate text from an image using the model and processor.
+    Main inference function to handle Gradio inputs and manage memory cleanup.
     """
-    conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
-    ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-    inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-    return processor.decode(output[0])[len(prompt):]
+    # Initialize the inference instance (this loads the model)
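+    # Note: re-loading the model on every request keeps GPU memory free between calls at the cost of latency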
+    llama_inference = LlamaInference()
 
+    try:
+        # Process the image and generate text
+        processed_image = llama_inference.process_image(image)
+        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
+    finally:
+        # Release the model and GPU memory even if generation fails
+        llama_inference.cleanup()
 
-def inference(image, prompt_text, temperature, top_p):
-    """
-    Wrapper function to load the model and generate text based on inputs from Gradio UI.
-    """
-    hf_token = os.getenv("HF_TOKEN")  # Get the Hugging Face token from the environment
-    if hf_token is None:
-        return "Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable."
-
-    model, processor = load_model_and_processor(DEFAULT_MODEL, hf_token)
-    processed_image = process_image(image)
-    result = generate_text_from_image(model, processor, processed_image, prompt_text, temperature, top_p)
     return result