multi_modal_infer.py

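"""
CLI script that generates text from an image and a prompt using a Llama 3.2
Vision (Mllama) model via Hugging Face transformers and accelerate.
"""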
import argparse
import os
import sys

import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def load_model_and_processor(model_name: str):
    """
    Load the Mllama model and processor for the given checkpoint
    (11B or 90B Vision-Instruct).
    """
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
    )
    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
    model, processor = accelerator.prepare(model, processor)
    return model, processor


def process_image(image_path: str) -> PIL_Image.Image:
    """
    Open the image at the given path and convert it to RGB.
    """
    if not os.path.exists(image_path):
        print(f"The image file '{image_path}' does not exist.")
        sys.exit(1)
    with open(image_path, "rb") as f:
        # .convert() forces PIL to load the pixel data before the file is closed.
        return PIL_Image.open(f).convert("RGB")


def generate_text_from_image(
    model, processor, image, prompt_text: str, temperature: float, top_p: float
):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=512,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return processor.decode(generated_tokens, skip_special_tokens=True)


def main(
    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
):
    """
    Run the full pipeline: load the model, process the image, and generate text.
    """
    model, processor = load_model_and_processor(model_name)
    image = process_image(image_path)
    result = generate_text_from_image(
        model, processor, image, prompt_text, temperature, top_p
    )
    print("Generated Text: " + result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate text from an image and a prompt using a Llama 3.2 multimodal model."
    )
    parser.add_argument(
        "--image_path", type=str, required=True, help="Path to the image file"
    )
    parser.add_argument(
        "--prompt_text", type=str, required=True, help="Prompt text to describe the image"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for generation (default: 0.7)",
    )
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Model name (default: '{DEFAULT_MODEL}')",
    )
    args = parser.parse_args()
    main(
        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
    )
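
# Example invocation (hypothetical paths and prompt; adjust to your environment):
#   python multi_modal_infer.py \
#       --image_path ./sample.jpg \
#       --prompt_text "Describe this image in two sentences." \
#       --temperature 0.6 --top_p 0.9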