# multi_modal_infer_Gradio_UI.py (4.9 KB)
  1. import gradio as gr
  2. import torch
  3. import os
  4. from PIL import Image
  5. from accelerate import Accelerator
  6. from transformers import MllamaForConditionalGeneration, AutoProcessor
  7. import argparse # Import argparse
  8. # Parse the command line arguments
  9. parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
  10. parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
  11. args = parser.parse_args()
  12. # Hugging Face token
  13. hf_token = args.hf_token
  14. # Initialize Accelerator
  15. accelerate = Accelerator()
  16. device = accelerate.device
  17. # Set memory management for PyTorch
  18. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed
  19. # Model ID
  20. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
  21. # Load model with the Hugging Face token
  22. model = MllamaForConditionalGeneration.from_pretrained(
  23. model_id,
  24. torch_dtype=torch.bfloat16,
  25. device_map=device,
  26. use_auth_token=hf_token # Pass the Hugging Face token here
  27. )
  28. # Load the processor
  29. processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
  30. # Visual theme
  31. visual_theme = gr.themes.Default() # Default, Soft or Monochrome
  32. # Constants
  33. MAX_OUTPUT_TOKENS = 2048
  34. MAX_IMAGE_SIZE = (1120, 1120)
  35. # Function to process the image and generate a description
  36. def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
  37. # Resize image if necessary
  38. image = image.resize(MAX_IMAGE_SIZE)
  39. # Initialize cleaned_output variable
  40. cleaned_output = ""
  41. prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
  42. # Preprocess the image and prompt
  43. inputs = processor(image, prompt, return_tensors="pt").to(device)
  44. # Generate output with model
  45. output = model.generate(
  46. **inputs,
  47. max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
  48. temperature=temperature,
  49. top_k=top_k,
  50. top_p=top_p
  51. )
  52. # Decode the raw output
  53. raw_output = processor.decode(output[0])
  54. # Clean up the output to remove system tokens
  55. cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
  56. # Ensure the prompt is not repeated in the output
  57. if cleaned_output.startswith(user_prompt):
  58. cleaned_output = cleaned_output[len(user_prompt):].strip()
  59. # Append the new conversation to the history
  60. history.append((user_prompt, cleaned_output))
  61. return history
  62. # Function to clear the chat history
  63. def clear_chat():
  64. return []
  65. # Gradio Interface
  66. def gradio_interface():
  67. with gr.Blocks(visual_theme) as demo:
  68. gr.HTML(
  69. """
  70. <h1 style='text-align: center'>
  71. meta-llama/Llama-3.2-11B-Vision-Instruct
  72. </h1>
  73. """)
  74. with gr.Row():
  75. # Left column with image and parameter inputs
  76. with gr.Column(scale=1):
  77. image_input = gr.Image(
  78. label="Image",
  79. type="pil",
  80. image_mode="RGB",
  81. height=512, # Set the height
  82. width=512 # Set the width
  83. )
  84. # Parameter sliders
  85. temperature = gr.Slider(
  86. label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1, interactive=True)
  87. top_k = gr.Slider(
  88. label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
  89. top_p = gr.Slider(
  90. label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
  91. max_tokens = gr.Slider(
  92. label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
  93. # Right column with the chat interface
  94. with gr.Column(scale=2):
  95. chat_history = gr.Chatbot(label="Chat", height=512)
  96. # User input box for prompt
  97. user_prompt = gr.Textbox(
  98. show_label=False,
  99. container=False,
  100. placeholder="Enter your prompt",
  101. lines=2
  102. )
  103. # Generate and Clear buttons
  104. with gr.Row():
  105. generate_button = gr.Button("Generate")
  106. clear_button = gr.Button("Clear")
  107. # Define the action for the generate button
  108. generate_button.click(
  109. fn=describe_image,
  110. inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
  111. outputs=[chat_history]
  112. )
  113. # Define the action for the clear button
  114. clear_button.click(
  115. fn=clear_chat,
  116. inputs=[],
  117. outputs=[chat_history]
  118. )
  119. return demo
  120. # Launch the interface
  121. demo = gradio_interface()
  122. demo.launch()