Selaa lähdekoodia

added the passing of hugging-face token from the argument

himanshushukla12 6 kuukautta sitten
vanhempi
commit
6f7c028e6a

+ 22 - 8
recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py

@@ -3,22 +3,37 @@ import torch
 import os
 from PIL import Image
 from accelerate import Accelerator
-from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM, AutoProcessor, GenerationConfig
-accelerate=Accelerator()
+from transformers import MllamaForConditionalGeneration, AutoProcessor
+import argparse  # Import argparse
+
+# Parse the command line arguments
+parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
+parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
+args = parser.parse_args()
+
+# Hugging Face token
+hf_token = args.hf_token
+
+# Initialize Accelerator
+accelerate = Accelerator()
 device = accelerate.device
+
 # Set memory management for PyTorch
 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # or adjust size as needed
 
-
-
+# Model ID
 model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+# Load model with the Hugging Face token
 model = MllamaForConditionalGeneration.from_pretrained(
     model_id,
     torch_dtype=torch.bfloat16,
     device_map=device,
+    use_auth_token=hf_token  # Pass the Hugging Face token here
 )
-processor = AutoProcessor.from_pretrained(model_id)
 
+# Load the processor
+processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
 
 # Visual theme
 visual_theme = gr.themes.Default()  # Default, Soft or Monochrome
@@ -35,7 +50,6 @@ def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, hi
     # Initialize cleaned_output variable
     cleaned_output = ""
 
-
     prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
     # Preprocess the image and prompt
     inputs = processor(image, prompt, return_tensors="pt").to(device)
@@ -75,7 +89,7 @@ def gradio_interface():
         gr.HTML(
         """
     <h1 style='text-align: center'>
-    Clean-UI
+    meta-llama/Llama-3.2-11B-Vision-Instruct
     </h1>
     """)
         with gr.Row():
@@ -134,4 +148,4 @@ def gradio_interface():
 
 # Launch the interface
 demo = gradio_interface()
-demo.launch()
+demo.launch()