|
@@ -3,22 +3,37 @@ import torch
|
|
|
import os
|
|
|
from PIL import Image
|
|
|
from accelerate import Accelerator
|
|
|
-from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
|
|
-accelerate=Accelerator()
|
|
|
+from transformers import MllamaForConditionalGeneration, AutoProcessor
|
|
|
+import argparse # Used to read the --hf_token command-line option
|
|
|
+
|
|
|
+# Parse the command-line arguments (--hf_token is required)
|
|
|
+parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
|
|
|
+parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
|
|
|
+args = parser.parse_args()
|
|
|
+
|
|
|
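+# Hugging Face token (NOTE(review): a secret passed on the command line is visible in shell history and process listings; consider reading it from an environment variable instead)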
|
|
|
+hf_token = args.hf_token
|
|
|
+
|
|
|
+# Initialize Accelerator; its device is used for model placement and input tensors
|
|
|
+accelerate = Accelerator()
|
|
|
device = accelerate.device
|
|
|
+
|
|
|
# Set memory management for PyTorch
|
|
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed
|
|
|
|
|
|
-
|
|
|
-
|
|
|
+# Model ID
|
|
|
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
|
|
+
|
|
|
+# Load model with the Hugging Face token
|
|
|
model = MllamaForConditionalGeneration.from_pretrained(
|
|
|
model_id,
|
|
|
torch_dtype=torch.bfloat16,
|
|
|
device_map=device,
|
|
|
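+ use_auth_token=hf_token # Authenticates the download of the model; NOTE(review): use_auth_token is deprecated in recent transformers releases in favor of token= (confirm against the pinned version)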
|
|
|
)
|
|
|
-processor = AutoProcessor.from_pretrained(model_id)
|
|
|
|
|
|
+# Load the processor
|
|
|
+processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
|
|
|
|
|
|
# Visual theme
|
|
|
visual_theme = gr.themes.Default() # Default, Soft or Monochrome
|
|
@@ -35,7 +50,6 @@ def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, hi
|
|
|
# Initialize cleaned_output variable
|
|
|
cleaned_output = ""
|
|
|
|
|
|
-
|
|
|
prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
|
|
|
# Preprocess the image and prompt
|
|
|
inputs = processor(image, prompt, return_tensors="pt").to(device)
|
|
@@ -75,7 +89,7 @@ def gradio_interface():
|
|
|
gr.HTML(
|
|
|
"""
|
|
|
<h1 style='text-align: center'>
|
|
|
- Clean-UI
|
|
|
+ meta-llama/Llama-3.2-11B-Vision-Instruct
|
|
|
</h1>
|
|
|
""")
|
|
|
with gr.Row():
|
|
@@ -134,4 +148,4 @@ def gradio_interface():
|
|
|
|
|
|
# Launch the interface
|
|
|
demo = gradio_interface()
|
|
|
-demo.launch()
|
|
|
+demo.launch()
|