
Added passing of the Hugging Face token via command-line arguments

himanshushukla12 committed 6 months ago
parent commit 750b499d14

+14 −7  recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py

@@ -5,6 +5,7 @@ from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
 import gradio as gr
 import gc  # Import garbage collector
+import argparse  # Import argparse for command-line arguments
 
 
 # Constants
@@ -19,10 +20,10 @@ class LlamaInference:
         self.device = self.accelerator.device
         
         self.model_name = model_name
-        self.hf_token = hf_token or os.getenv("HF_TOKEN")
+        self.hf_token = hf_token
         
         if self.hf_token is None:
-            raise ValueError("Error: Hugging Face token not found in environment. Please set the HF_TOKEN environment variable.")
+            raise ValueError("Error: Hugging Face token not provided.")
 
         # Load model and processor
         self.model, self.processor = self.load_model_and_processor()
@@ -85,12 +86,12 @@ class LlamaInference:
         print("Cleanup complete. Instance variables deleted and memory cleared.")
 
 
-def inference(image, prompt_text, temperature, top_p):
+def inference(image, prompt_text, temperature, top_p, hf_token):
     """
     Main inference function to handle Gradio inputs and manage memory cleanup.
     """
     # Initialize the inference instance (this loads the model)
-    llama_inference = LlamaInference()
+    llama_inference = LlamaInference(hf_token=hf_token)
 
     try:
         # Process the image and generate text
@@ -104,7 +105,7 @@ def inference(image, prompt_text, temperature, top_p):
 
 
 # Gradio UI
-def create_gradio_interface():
+def create_gradio_interface(hf_token):
     """
     Create the Gradio interface for image-to-text generation.
     """
@@ -119,7 +120,7 @@ def create_gradio_interface():
 
     # Create the interface
     interface = gr.Interface(
-        fn=inference,
+        fn=lambda image, prompt_text, temperature, top_p: inference(image, prompt_text, temperature, top_p, hf_token),
         inputs=[image_input, prompt_input, temperature_input, top_p_input],
         outputs=output_text,
         title="LLama-3.2 Vision-Instruct",
@@ -132,4 +133,10 @@ def create_gradio_interface():
 
 
 if __name__ == "__main__":
-    create_gradio_interface()
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description="Run LLama-3.2 Vision-Instruct with HF token passed via arguments.")
+    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token to access the model")
+    args = parser.parse_args()
+
+    # Pass the HF token to Gradio interface
+    create_gradio_interface(hf_token=args.hf_token)
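
With this change the token must be supplied on the command line instead of being read from the HF_TOKEN environment variable. A minimal launch, assuming the script path shown in the diff and a valid token:

    python recipes/quickstart/inference/local_inference/multi_modal_infer_Gradio_UI.py --hf_token <YOUR_HF_TOKEN>

The lambda used to bind hf_token into the Gradio callback could also be written with functools.partial, which avoids restating the argument list. This is only an equivalent sketch of that alternative, not what the commit itself does:

    from functools import partial

    # partial binds hf_token as a keyword argument; Gradio still supplies
    # (image, prompt_text, temperature, top_p) positionally at call time.
    interface = gr.Interface(
        fn=partial(inference, hf_token=hf_token),
        inputs=[image_input, prompt_input, temperature_input, top_p_input],
        outputs=output_text,
        title="LLama-3.2 Vision-Instruct",
    )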