multi_modal_infer.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import os
  2. import sys
  3. import argparse
  4. from PIL import Image as PIL_Image
  5. import torch
  6. from transformers import MllamaForConditionalGeneration, MllamaProcessor
  7. from accelerate import Accelerator
  8. from peft import PeftModel # Make sure to install the `peft` library
# Single Accelerator instance shared by the script; load_model_and_processor calls prepare() on it.
accelerator = Accelerator()
# Device reported by the accelerator; used as the model's device_map and as the target for input tensors.
device = accelerator.device
# Constants
# Model identifier used when --model_name is not supplied on the command line.
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
  13. def load_model_and_processor(model_name: str, hf_token: str, finetuning_path: str = None):
  14. """
  15. Load the model and processor, and optionally load adapter weights if specified
  16. """
  17. # Load pre-trained model and processor
  18. model = MllamaForConditionalGeneration.from_pretrained(
  19. model_name,
  20. torch_dtype=torch.bfloat16,
  21. use_safetensors=True,
  22. device_map=device,
  23. token=hf_token
  24. )
  25. processor = MllamaProcessor.from_pretrained(
  26. model_name,
  27. token=hf_token,
  28. use_safetensors=True
  29. )
  30. # If a finetuning path is provided, load the adapter model
  31. if finetuning_path and os.path.exists(finetuning_path):
  32. adapter_weights_path = os.path.join(finetuning_path, "adapter_model.safetensors")
  33. adapter_config_path = os.path.join(finetuning_path, "adapter_config.json")
  34. if os.path.exists(adapter_weights_path) and os.path.exists(adapter_config_path):
  35. print(f"Loading adapter from '{finetuning_path}'...")
  36. # Load the model with adapters using `peft`
  37. model = PeftModel.from_pretrained(
  38. model,
  39. finetuning_path, # This should be the folder containing the adapter files
  40. is_adapter=True,
  41. torch_dtype=torch.bfloat16
  42. )
  43. print("Adapter merged successfully with the pre-trained model.")
  44. else:
  45. print(f"Adapter files not found in '{finetuning_path}'. Using pre-trained model only.")
  46. else:
  47. print(f"No fine-tuned weights or adapters found in '{finetuning_path}'. Using pre-trained model only.")
  48. # Prepare the model and processor for accelerated training
  49. model, processor = accelerator.prepare(model, processor)
  50. return model, processor
  51. def process_image(image_path: str) -> PIL_Image.Image:
  52. """
  53. Open and convert an image from the specified path.
  54. """
  55. if not os.path.exists(image_path):
  56. print(f"The image file '{image_path}' does not exist.")
  57. sys.exit(1)
  58. with open(image_path, "rb") as f:
  59. return PIL_Image.open(f).convert("RGB")
  60. def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
  61. """
  62. Generate text from an image using the model and processor.
  63. """
  64. conversation = [
  65. {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
  66. ]
  67. prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
  68. inputs = processor(image, prompt, return_tensors="pt").to(device)
  69. output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=2048)
  70. return processor.decode(output[0])[len(prompt):]
  71. def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str, finetuning_path: str = None):
  72. """
  73. Call all the functions and optionally merge adapter weights from a specified path.
  74. """
  75. model, processor = load_model_and_processor(model_name, hf_token, finetuning_path)
  76. image = process_image(image_path)
  77. result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
  78. print("Generated Text: " + result)
  79. if __name__ == "__main__":
  80. # Example usage with argparse (optional)
  81. parser = argparse.ArgumentParser(description="Generate text from an image using a fine-tuned model with adapters.")
  82. parser.add_argument("--image_path", type=str, required=True, help="Path to the input image.")
  83. parser.add_argument("--prompt_text", type=str, required=True, help="Prompt text for the image.")
  84. parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
  85. parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling.")
  86. parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Pre-trained model name.")
  87. parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face API token.")
  88. parser.add_argument("--finetuning_path", type=str, help="Path to the fine-tuning weights (adapters).")
  89. args = parser.parse_args()
  90. main(
  91. image_path=args.image_path,
  92. prompt_text=args.prompt_text,
  93. temperature=args.temperature,
  94. top_p=args.top_p,
  95. model_name=args.model_name,
  96. hf_token=args.hf_token,
  97. finetuning_path=args.finetuning_path
  98. )