# multi_modal_infer.py (5.4 KB)

  1. import argparse
  2. import os
  3. import sys
  4. import torch
  5. from accelerate import Accelerator
  6. from PIL import Image as PIL_Image
  7. from transformers import MllamaForConditionalGeneration, MllamaProcessor
  8. from peft import PeftModel
  9. import gradio as gr
# Initialize accelerator
# One Accelerator is created at import time; `device` is the device it reports
# and is used for model placement (device_map) and for moving processor inputs.
accelerator = Accelerator()
device = accelerator.device
# Constants
# Model identifier used when --model_name is not given on the command line.
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Passed to generate() as max_new_tokens and used as the upper bound of the
# "Max Tokens" slider in the Gradio UI.
MAX_OUTPUT_TOKENS = 2048
# Size that images uploaded through the Gradio UI are resized to.
MAX_IMAGE_SIZE = (1120, 1120)
  17. def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
  18. """Load model and processor with optional LoRA adapter"""
  19. model = MllamaForConditionalGeneration.from_pretrained(
  20. model_name,
  21. torch_dtype=torch.bfloat16,
  22. use_safetensors=True,
  23. device_map=device,
  24. token=hf_token
  25. )
  26. processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
  27. if finetuning_path and os.path.exists(finetuning_path):
  28. print(f"Loading adapter from '{finetuning_path}'...")
  29. model = PeftModel.from_pretrained(
  30. model,
  31. finetuning_path,
  32. is_adapter=True,
  33. torch_dtype=torch.bfloat16
  34. )
  35. print("Adapter merged successfully")
  36. model, processor = accelerator.prepare(model, processor)
  37. return model, processor
  38. def process_image(image_path: str) -> PIL_Image.Image:
  39. """Process and validate image input"""
  40. if not os.path.exists(image_path):
  41. print(f"Image file '{image_path}' does not exist.")
  42. sys.exit(1)
  43. return PIL_Image.open(image_path).convert("RGB")
  44. def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
  45. """Generate text from image using model"""
  46. conversation = [
  47. {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
  48. ]
  49. prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
  50. inputs = processor(image, prompt, return_tensors="pt").to(device)
  51. output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
  52. return processor.decode(output[0])[len(prompt):]
  53. def gradio_interface(model, processor):
  54. """Create Gradio UI"""
  55. def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
  56. if image is not None:
  57. image = image.resize(MAX_IMAGE_SIZE)
  58. result = generate_text_from_image(model, processor, image, user_prompt, temperature, top_p)
  59. history.append((user_prompt, result))
  60. return history
  61. def clear_chat():
  62. return []
  63. with gr.Blocks() as demo:
  64. gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
  65. with gr.Row():
  66. with gr.Column(scale=1):
  67. image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
  68. temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
  69. top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
  70. top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
  71. max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
  72. with gr.Column(scale=2):
  73. chat_history = gr.Chatbot(label="Chat", height=512)
  74. user_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", lines=2)
  75. with gr.Row():
  76. generate_button = gr.Button("Generate")
  77. clear_button = gr.Button("Clear")
  78. generate_button.click(
  79. fn=describe_image,
  80. inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
  81. outputs=[chat_history]
  82. )
  83. clear_button.click(fn=clear_chat, outputs=[chat_history])
  84. return demo
  85. def main(args):
  86. """Main execution flow"""
  87. model, processor = load_model_and_processor(
  88. args.model_name,
  89. args.hf_token,
  90. args.finetuning_path
  91. )
  92. if args.gradio_ui:
  93. demo = gradio_interface(model, processor)
  94. demo.launch()
  95. else:
  96. image = process_image(args.image_path)
  97. result = generate_text_from_image(
  98. model, processor, image, args.prompt_text, args.temperature, args.top_p
  99. )
  100. print("Generated Text:", result)
  101. if __name__ == "__main__":
  102. parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
  103. parser.add_argument("--image_path", type=str, help="Path to the input image")
  104. parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
  105. parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
  106. parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
  107. parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
  108. parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
  109. parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
  110. parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
  111. args = parser.parse_args()
  112. main(args)