@@ -5,6 +5,7 @@ from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
 import gradio as gr
 import gc # Import garbage collector
+import argparse # Import argparse for command-line arguments
 
 
 # Constants
@@ -19,10 +20,10 @@ class LlamaInference:
         self.device = self.accelerator.device
 
         self.model_name = model_name
-        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+        self.hf_token = hf_token
 
         if self.hf_token is None:
-            raise ValueError("Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable.")
+            raise ValueError("Error: Hugging Face token not provided.")
 
         # Load model and processor
         self.model, self.processor = self.load_model_and_processor()
@@ -85,12 +86,12 @@ class LlamaInference:
         print("Cleanup complete. Instance variables deleted and memory cleared.")
 
 
-def inference(image, prompt_text, temperature, top_p):
+def inference(image, prompt_text, temperature, top_p, hf_token):
     """
     Main inference function to handle Gradio inputs and manage memory cleanup.
     """
     # Initialize the inference instance (this loads the model)
-    llama_inference = LlamaInference()
+    llama_inference = LlamaInference(hf_token=hf_token)
 
     try:
         # Process the image and generate text
@@ -104,7 +105,7 @@ def inference(image, prompt_text, temperature, top_p):
 
 
 # Gradio UI
-def create_gradio_interface():
+def create_gradio_interface(hf_token):
     """
     Create the Gradio interface for image-to-text generation.
     """
@@ -119,7 +120,7 @@ def create_gradio_interface():
 
     # Create the interface
     interface = gr.Interface(
-        fn=inference,
+        fn=lambda image, prompt_text, temperature, top_p: inference(image, prompt_text, temperature, top_p, hf_token),
        inputs=[image_input, prompt_input, temperature_input, top_p_input],
         outputs=output_text,
         title="LLama-3.2 Vision-Instruct",
@@ -132,4 +133,10 @@ def create_gradio_interface():
 
 
 if __name__ == "__main__":
-    create_gradio_interface()
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description="Run LLama-3.2 Vision-Instruct with HF token passed via arguments.")
+    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token to access the model")
+    args = parser.parse_args()
+
+    # Pass the HF token to Gradio interface
+    create_gradio_interface(hf_token=args.hf_token)
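
With this change the token is supplied as a required command-line argument instead of being read from the HF_TOKEN environment variable. A minimal sketch of the new invocation, assuming the script is saved as app.py (the filename is not given in the diff):

    python app.py --hf_token <your_hf_token>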