# multi_modal_infer_Gradio_UI.py
  1. import gradio as gr
  2. import torch
  3. import os
  4. from PIL import Image
  5. from accelerate import Accelerator
  6. from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM, AutoProcessor, GenerationConfig
  7. accelerate=Accelerator()
  8. device = accelerate.device
  9. # Set memory management for PyTorch
  10. os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed
  11. model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
  12. model = MllamaForConditionalGeneration.from_pretrained(
  13. model_id,
  14. torch_dtype=torch.bfloat16,
  15. device_map=device,
  16. )
  17. processor = AutoProcessor.from_pretrained(model_id)
  18. # Visual theme
  19. visual_theme = gr.themes.Default() # Default, Soft or Monochrome
  20. # Constants
  21. MAX_OUTPUT_TOKENS = 2048
  22. MAX_IMAGE_SIZE = (1120, 1120)
  23. # Function to process the image and generate a description
  24. def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
  25. # Resize image if necessary
  26. image = image.resize(MAX_IMAGE_SIZE)
  27. # Initialize cleaned_output variable
  28. cleaned_output = ""
  29. prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
  30. # Preprocess the image and prompt
  31. inputs = processor(image, prompt, return_tensors="pt").to(device)
  32. # Generate output with model
  33. output = model.generate(
  34. **inputs,
  35. max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
  36. temperature=temperature,
  37. top_k=top_k,
  38. top_p=top_p
  39. )
  40. # Decode the raw output
  41. raw_output = processor.decode(output[0])
  42. # Clean up the output to remove system tokens
  43. cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
  44. # Ensure the prompt is not repeated in the output
  45. if cleaned_output.startswith(user_prompt):
  46. cleaned_output = cleaned_output[len(user_prompt):].strip()
  47. # Append the new conversation to the history
  48. history.append((user_prompt, cleaned_output))
  49. return history
  50. # Function to clear the chat history
  51. def clear_chat():
  52. return []
  53. # Gradio Interface
  54. def gradio_interface():
  55. with gr.Blocks(visual_theme) as demo:
  56. gr.HTML(
  57. """
  58. <h1 style='text-align: center'>
  59. Clean-UI
  60. </h1>
  61. """)
  62. with gr.Row():
  63. # Left column with image and parameter inputs
  64. with gr.Column(scale=1):
  65. image_input = gr.Image(
  66. label="Image",
  67. type="pil",
  68. image_mode="RGB",
  69. height=512, # Set the height
  70. width=512 # Set the width
  71. )
  72. # Parameter sliders
  73. temperature = gr.Slider(
  74. label="Temperature", minimum=0.1, maximum=2.0, value=0.6, step=0.1, interactive=True)
  75. top_k = gr.Slider(
  76. label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
  77. top_p = gr.Slider(
  78. label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
  79. max_tokens = gr.Slider(
  80. label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
  81. # Right column with the chat interface
  82. with gr.Column(scale=2):
  83. chat_history = gr.Chatbot(label="Chat", height=512)
  84. # User input box for prompt
  85. user_prompt = gr.Textbox(
  86. show_label=False,
  87. container=False,
  88. placeholder="Enter your prompt",
  89. lines=2
  90. )
  91. # Generate and Clear buttons
  92. with gr.Row():
  93. generate_button = gr.Button("Generate")
  94. clear_button = gr.Button("Clear")
  95. # Define the action for the generate button
  96. generate_button.click(
  97. fn=describe_image,
  98. inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
  99. outputs=[chat_history]
  100. )
  101. # Define the action for the clear button
  102. clear_button.click(
  103. fn=clear_chat,
  104. inputs=[],
  105. outputs=[chat_history]
  106. )
  107. return demo
  108. # Launch the interface
  109. demo = gradio_interface()
  110. demo.launch()