multi_modal_infer.py

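"""Multi-modal inference for Llama 3.2 Vision models.

Supports single-shot image + prompt inference from the command line, an
optional Gradio chat UI, and loading a LoRA adapter on top of the base model.
"""
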
import argparse
import os
import sys

import gradio as gr
import torch
from accelerate import Accelerator
from huggingface_hub import HfFolder
from peft import PeftModel
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)


def get_hf_token():
    """Retrieve Hugging Face token from the environment or the local cache."""
    # Check if a token is explicitly set in the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        return token

    # Automatically retrieve the token from the Hugging Face cache (set via huggingface-cli login)
    token = HfFolder.get_token()
    if token:
        return token

    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
    sys.exit(1)


def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter."""
    print(f"Loading model: {model_name}")
    hf_token = get_hf_token()
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(
        model_name, token=hf_token, use_safetensors=True
    )

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model, finetuning_path, is_adapter=True, torch_dtype=torch.bfloat16
        )
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor


def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input."""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")


def generate_text_from_image(
    model, processor, image, prompt_text: str, temperature: float, top_p: float
):
    """Generate text from image using model."""
    conversation = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    inputs = processor(
        image, prompt, text_kwargs={"add_special_tokens": False}, return_tensors="pt"
    ).to(device)
    print("Input Prompt:\n", processor.tokenizer.decode(inputs.input_ids[0]))
    output = model.generate(
        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS
    )
    return processor.decode(output[0])[len(prompt) :]


def gradio_interface(model_name: str):
    """Create Gradio UI with LoRA support."""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name, lora_path if enable_lora else None
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(
        image, user_prompt, temperature, top_k, top_p, max_tokens, history
    ):
        # Note: top_k and max_tokens are collected from the UI but are not
        # forwarded to generate_text_from_image, which only accepts
        # temperature and top_p.
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(
                    label="Image", type="pil", image_mode="RGB", height=512, width=512
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1
                )
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1
                )
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1
                )
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=50,
                    maximum=MAX_OUTPUT_TOKENS,
                    value=100,
                    step=50,
                )

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False, placeholder="Enter your prompt", lines=2
                )
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x), inputs=[enable_lora], outputs=[lora_path]
        )
        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )
        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input,
                user_prompt,
                temperature,
                top_k,
                top_p,
                max_tokens,
                chat_history,
            ],
            outputs=[chat_history],
        )
        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)
    return demo


def main(args):
    """Main execution flow."""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name, args.finetuning_path
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image, args.prompt_text, args.temperature, args.top_p
        )
        print("Generated Text:", result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Multi-modal inference with optional Gradio UI and LoRA support"
    )
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument(
        "--model_name", type=str, default=DEFAULT_MODEL, help="Model name"
    )
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")

    args = parser.parse_args()
    main(args)
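
# Illustrative invocations; the flags are defined above, but the image path,
# prompt, and LoRA directory shown here are placeholders, not part of the script:
#   python multi_modal_infer.py --image_path ./example.jpg --prompt_text "Describe this image."
#   python multi_modal_infer.py --gradio_ui
#   python multi_modal_infer.py --gradio_ui --finetuning_path ./lora_weights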