multi_modal_infer.py

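# Example invocations (a minimal sketch based on the argparse flags defined below;
# the image path, prompt, and LoRA folder are placeholders, not files shipped with
# this script):
#
#   # One-off CLI inference on a single image
#   python multi_modal_infer.py --image_path ./example.jpg --prompt_text "Describe this image"
#
#   # Same, with LoRA weights applied from a local folder
#   python multi_modal_infer.py --image_path ./example.jpg --prompt_text "Describe this image" \
#       --finetuning_path ./lora_weights
#
#   # Launch the Gradio UI instead of running a single generation
#   python multi_modal_infer.py --gradio_ui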
import argparse
import os
import sys

import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from peft import PeftModel
import gradio as gr
from huggingface_hub import HfFolder

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)

def get_hf_token():
    """Retrieve Hugging Face token from the cache or environment."""
    # Check if a token is explicitly set in the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        return token

    # Automatically retrieve the token from the Hugging Face cache (set via huggingface-cli login)
    token = HfFolder.get_token()
    if token:
        return token

    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
    sys.exit(1)

def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    hf_token = get_hf_token()
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16
        )
        # The adapter is attached (not merged into the base weights)
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor

def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")

def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from image using model"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    # Decode the full sequence and strip the prompt prefix, keeping only the generated text
    return processor.decode(output[0])[len(prompt):]

def gradio_interface(model_name: str):
    """Create Gradio UI with LoRA support"""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name,
            lora_path if enable_lora else None
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
        # Note: top_k and max_tokens come from the UI sliders but are not currently
        # forwarded to generation, which uses temperature/top_p and MAX_OUTPUT_TOKENS.
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter your prompt",
                    lines=2
                )

                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x),
            inputs=[enable_lora],
            outputs=[lora_path]
        )

        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status]
        )

        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input, user_prompt, temperature,
                top_k, top_p, max_tokens, chat_history
            ],
            outputs=[chat_history]
        )

        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)
    return demo

def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name,
            args.finetuning_path
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image,
            args.prompt_text,
            args.temperature,
            args.top_p
        )
        print("Generated Text:", result)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)