# multi_modal_infer_gradio_UI.py

import argparse
import os

import gradio as gr
import torch
from accelerate import Accelerator
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
  8. # Parse the command line arguments
  9. parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
  10. parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
  11. args = parser.parse_args()
  12. # Hugging Face token
  13. hf_token = args.hf_token
  14. # Initialize Accelerator
  15. accelerate = Accelerator()
  16. device = accelerate.device
  17. # Set memory management for PyTorch
  18. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed
  19. # Model ID
  20. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
  21. # Load model with the Hugging Face token
  22. model = MllamaForConditionalGeneration.from_pretrained(
  23. model_id,
  24. torch_dtype=torch.bfloat16,
  25. device_map=device,
  26. use_auth_token=hf_token # Pass the Hugging Face token here
  27. )
  28. # Load the processor
  29. processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
  30. # Visual theme
  31. visual_theme = gr.themes.Default() # Default, Soft or Monochrome
  32. # Constants
  33. MAX_OUTPUT_TOKENS = 2048
  34. MAX_IMAGE_SIZE = (1120, 1120)
  35. # Function to process the image and generate a description
  36. def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
  37. # Initialize cleaned_output variable
  38. cleaned_output = ""
  39. if image is not None:
  40. # Resize image if necessary
  41. image = image.resize(MAX_IMAGE_SIZE)
  42. prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
  43. # Preprocess the image and prompt
  44. inputs = processor(image, prompt, return_tensors="pt").to(device)
  45. else:
  46. # Text-only input if no image is provided
  47. prompt = f"<|begin_of_text|>{user_prompt} Answer:"
  48. # Preprocess the prompt only (no image)
  49. inputs = processor(prompt, return_tensors="pt").to(device)
  50. # Generate output with model
  51. output = model.generate(
  52. **inputs,
  53. max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
  54. temperature=temperature,
  55. top_k=top_k,
  56. top_p=top_p
  57. )
  58. # Decode the raw output
  59. raw_output = processor.decode(output[0])
  60. # Clean up the output to remove system tokens
  61. cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
  62. # Ensure the prompt is not repeated in the output
  63. if cleaned_output.startswith(user_prompt):
  64. cleaned_output = cleaned_output[len(user_prompt):].strip()
  65. # Append the new conversation to the history
  66. history.append((user_prompt, cleaned_output))
  67. return history
  68. # Function to clear the chat history
  69. def clear_chat():
  70. return []
  71. # Gradio Interface
  72. def gradio_interface():
  73. with gr.Blocks(visual_theme) as demo:
  74. gr.HTML(
  75. """
  76. <h1 style='text-align: center'>
  77. meta-llama/Llama-3.2-11B-Vision-Instruct
  78. </h1>
  79. """)
  80. with gr.Row():
  81. # Left column with image and parameter inputs
  82. with gr.Column(scale=1):
  83. image_input = gr.Image(
  84. label="Image",
  85. type="pil",
  86. image_mode="RGB",
  87. height=512, # Set the height
  88. width=512 # Set the width
  89. )
  90. # Parameter sliders
  91. temperature = gr.Slider(
  92. label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
  93. top_k = gr.Slider(
  94. label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
  95. top_p = gr.Slider(
  96. label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
  97. max_tokens = gr.Slider(
  98. label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
  99. # Right column with the chat interface
  100. with gr.Column(scale=2):
  101. chat_history = gr.Chatbot(label="Chat", height=512)
  102. # User input box for prompt
  103. user_prompt = gr.Textbox(
  104. show_label=False,
  105. container=False,
  106. placeholder="Enter your prompt",
  107. lines=2
  108. )
  109. # Generate and Clear buttons
  110. with gr.Row():
  111. generate_button = gr.Button("Generate")
  112. clear_button = gr.Button("Clear")
  113. # Define the action for the generate button
  114. generate_button.click(
  115. fn=describe_image,
  116. inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
  117. outputs=[chat_history]
  118. )
  119. # Define the action for the clear button
  120. clear_button.click(
  121. fn=clear_chat,
  122. inputs=[],
  123. outputs=[chat_history]
  124. )
  125. return demo
  126. # Launch the interface
  127. demo = gradio_interface()
  128. # demo.launch(server_name="0.0.0.0", server_port=12003)
  129. demo.launch()