multi_modal_infer.py

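"""Multi-modal (image + text) inference for Llama 3.2 Vision models.

Runs a single prompt against an image from the command line, or launches a
Gradio chat UI (--gradio_ui). An optional LoRA adapter can be loaded on top of
the base model via --finetuning_path (CLI) or the UI's LoRA controls.
"""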
import argparse
import os
import sys

import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from peft import PeftModel
import gradio as gr
from huggingface_hub import HfFolder, login

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)

def get_hf_token():
    """Retrieve Hugging Face token from the cache or environment."""
    # Check if a token is explicitly set in the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        return token

    # Automatically retrieve the token from the Hugging Face cache (set via `huggingface-cli login`)
    token = HfFolder.get_token()
    if token:
        return token

    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
    sys.exit(1)
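
# The token can be supplied through either source checked above, e.g. (illustrative shell commands):
#   export HUGGINGFACE_TOKEN=hf_xxx   # hf_xxx is a placeholder token value
#   huggingface-cli login             # caches a token that HfFolder.get_token() reads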

def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    hf_token = get_hf_token()
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16,
        )
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor

def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")

def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from image using model"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    # Strip the echoed prompt from the decoded output (the decoded string starts with the prompt text)
    return processor.decode(output[0])[len(prompt):]
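
# Illustrative standalone usage of the helpers above ("example.jpg" is a placeholder path):
#   model, processor = load_model_and_processor(DEFAULT_MODEL)
#   image = process_image(image_path="example.jpg")
#   print(generate_text_from_image(model, processor, image, "Describe this image.", 0.7, 0.9))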

def gradio_interface(model_name: str):
    """Create Gradio UI with LoRA support"""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name,
            lora_path if enable_lora else None,
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
        # Note: top_k and max_tokens are collected from the UI but not forwarded to generation here
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt",
                    lines=2,
                )
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[enable_lora],
            outputs=[lora_path],
        )
        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )
        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input, user_prompt, temperature,
                top_k, top_p, max_tokens, chat_history,
            ],
            outputs=[chat_history],
        )
        clear_button.click(fn=clear_chat, outputs=[chat_history])

        # Initial model load
        load_or_reload_model(False)

    return demo

def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name,
            args.finetuning_path,
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image,
            args.prompt_text,
            args.temperature,
            args.top_p,
        )
        print("Generated Text:", result)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)
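
# Example invocations (illustrative; image and adapter paths are placeholders):
#   python multi_modal_infer.py --image_path ./dog.jpg --prompt_text "Describe this image"
#   python multi_modal_infer.py --image_path ./dog.jpg --prompt_text "Describe this image" \
#       --finetuning_path ./lora_weights
#   python multi_modal_infer.py --gradio_ui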