# multi_modal_infer.py

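"""Multi-modal inference with Llama 3.2 Vision (Mllama).

Runs image + text inference either from the command line or through a Gradio UI,
with optional loading of a LoRA adapter on top of the base model.
"""
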
import argparse
import os
import sys

import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from peft import PeftModel
import gradio as gr

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)


def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16,
        )
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor


def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")


def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from image using model"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    # Drop the echoed prompt from the decoded output and return only the generated continuation
    return processor.decode(output[0])[len(prompt):]


def gradio_interface(model_name: str, hf_token: str):
    """Create Gradio UI with LoRA support"""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name,
            hf_token,
            lora_path if enable_lora else None,
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
        # Note: top_k and max_tokens come from the UI sliders but are not forwarded to
        # generate_text_from_image, which only uses temperature and top_p
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt",
                    lines=2,
                )
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[enable_lora],
            outputs=[lora_path],
        )
        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )
        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input, user_prompt, temperature,
                top_k, top_p, max_tokens, chat_history,
            ],
            outputs=[chat_history],
        )
        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)
    return demo


def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name, args.hf_token)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name,
            args.hf_token,
            args.finetuning_path,
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image,
            args.prompt_text,
            args.temperature,
            args.top_p,
        )
        print("Generated Text:", result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)
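
# Example invocations (the image path, prompt text, and LoRA directory below are placeholders,
# not files shipped with this script):
#   python multi_modal_infer.py --image_path ./sample.jpg --prompt_text "Describe this image." --hf_token <HF_TOKEN>
#   python multi_modal_infer.py --image_path ./sample.jpg --prompt_text "Describe this image." --finetuning_path ./lora_weights
#   python multi_modal_infer.py --gradio_ui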