
Implemented memory management to release GPU resources after inference

himanshushukla12 6 months ago
parent
commit
22be586773
1 changed file with 85 additions and 42 deletions

+ 85 - 42
recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py

@@ -4,59 +4,102 @@ import torch
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
 import gradio as gr
+import gc  # Import garbage collector
 
-# Initialize accelerator
-accelerator = Accelerator()
-
-device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
-
-def load_model_and_processor(model_name: str, hf_token: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True, 
-                                                            device_map=device, token=hf_token)
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
-
-    model, processor = accelerator.prepare(model, processor)
-    return model, processor
-
-
-def process_image(image) -> PIL_Image.Image:
-    """
-    Open and convert an uploaded image to RGB format.
-    """
-    return image.convert("RGB")
+class LlamaInference:
+    def __init__(self, model_name=DEFAULT_MODEL, hf_token=None):
+        """
+        Initialize the inference class. Load model and processor.
+        """
+        self.accelerator = Accelerator()
+        self.device = self.accelerator.device
+        
+        self.model_name = model_name
+        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+        
+        if self.hf_token is None:
+            raise ValueError("Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable.")
+
+        # Load model and processor
+        self.model, self.processor = self.load_model_and_processor()
+
+    def load_model_and_processor(self):
+        """
+        Load the model and processor based on the model name.
+        """
+        model = MllamaForConditionalGeneration.from_pretrained(self.model_name, 
+                                                               torch_dtype=torch.bfloat16, 
+                                                               use_safetensors=True, 
+                                                               device_map=self.device, 
+                                                               token=self.hf_token)
+        processor = MllamaProcessor.from_pretrained(self.model_name, 
+                                                    token=self.hf_token, 
+                                                    use_safetensors=True)
+
+        # Prepare model and processor with accelerator
+        model, processor = self.accelerator.prepare(model, processor)
+        return model, processor
+
+    def process_image(self, image) -> PIL_Image.Image:
+        """
+        Open and convert an uploaded image to RGB format.
+        """
+        return image.convert("RGB")
+
+    def generate_text_from_image(self, image, prompt_text: str, temperature: float, top_p: float):
+        """
+        Generate text from an image using the model and processor.
+        """
+        conversation = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        ]
+        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+        inputs = self.processor(image, prompt, return_tensors="pt").to(self.device)
+
+        # Perform inference without computing gradients to save memory
+        with torch.no_grad():
+            output = self.model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
+        
+        return self.processor.decode(output[0])[len(prompt):]
+
+    def cleanup(self):
+        """
+        Clean up instance variables to release memory.
+        """
+        # Move model and processor to CPU before deleting to free up GPU memory
+        self.model.to('cpu')
+        del self.model
+        del self.processor
+        torch.cuda.empty_cache()  # Release GPU memory
+        gc.collect()  # Force garbage collection
+
+        # Clear other instance variables
+        del self.accelerator
+        del self.device
+        del self.hf_token
+
+        print("Cleanup complete. Instance variables deleted and memory cleared.")
 
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+def inference(image, prompt_text, temperature, top_p):
     """
-    Generate text from an image using the model and processor.
+    Main inference function to handle Gradio inputs and manage memory cleanup.
     """
-    conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
-    ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-    inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-    return processor.decode(output[0])[len(prompt):]
+    # Initialize the inference instance (this loads the model)
+    llama_inference = LlamaInference()
 
+    try:
+        # Process the image and generate text
+        processed_image = llama_inference.process_image(image)
+        result = llama_inference.generate_text_from_image(processed_image, prompt_text, temperature, top_p)
+    finally:
+        # Perform memory cleanup
+        llama_inference.cleanup()
 
-def inference(image, prompt_text, temperature, top_p):
-    """
-    Wrapper function to load the model and generate text based on inputs from Gradio UI.
-    """
-    hf_token = os.getenv("HF_TOKEN")  # Get the Hugging Face token from the environment
-    if hf_token is None:
-        return "Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable."
-    
-    model, processor = load_model_and_processor(DEFAULT_MODEL, hf_token)
-    processed_image = process_image(image)
-    result = generate_text_from_image(model, processor, processed_image, prompt_text, temperature, top_p)
     return result
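
For reference, the release pattern this commit introduces in `LlamaInference.cleanup()` can be exercised in isolation. The sketch below is not part of the commit: the `nn.Linear` module is only a stand-in for the Llama checkpoint, the printed sizes are illustrative, and a CUDA device is assumed. It follows the same steps as the commit's cleanup: move the weights to CPU, drop the Python reference, empty the CUDA cache, and force garbage collection.

import gc
import torch
from torch import nn

# Stand-in for the Llama checkpoint: any module with weights on the GPU.
model = nn.Linear(4096, 4096).to("cuda")
print(f"allocated before cleanup: {torch.cuda.memory_allocated() / 1e6:.1f} MB")

# Same steps as LlamaInference.cleanup():
model.to("cpu")            # move parameters off the GPU
del model                  # drop the last reference to the weights
torch.cuda.empty_cache()   # return cached blocks to the CUDA allocator
gc.collect()               # force collection of unreachable objects

print(f"allocated after cleanup:  {torch.cuda.memory_allocated() / 1e6:.1f} MB")

Because the instance in the diff owns the only reference to the model, deleting `self.model` and `self.processor` in `cleanup()` is enough for the tensors to become unreachable; a caller that kept its own reference to the model would also need to drop it before the memory could actually be returned.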