multi_modal_infer.py

import os
import sys
import argparse

from PIL import Image as PIL_Image
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def load_model_and_processor(model_name: str, hf_token: str):
    """
    Load the Mllama model and processor for the given 11B or 90B checkpoint.
    """
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
    model, processor = accelerator.prepare(model, processor)
    return model, processor
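
# A sketch of an alternative loading strategy (not part of the flow above):
# for the larger 90B checkpoint, device_map="auto" lets transformers shard
# the weights across all visible devices instead of pinning them to one:
#
#   model = MllamaForConditionalGeneration.from_pretrained(
#       model_name, torch_dtype=torch.bfloat16, device_map="auto", token=hf_token
#   )
#
# As a rough estimate, bfloat16 weights take 2 bytes per parameter
# (~22 GB for 11B, ~180 GB for 90B), excluding activations and KV cache.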


def process_image(image_path: str) -> PIL_Image.Image:
    """
    Open and convert an image from the specified path.
    """
    if not os.path.exists(image_path):
        print(f"The image file '{image_path}' does not exist.")
        sys.exit(1)
    with open(image_path, "rb") as f:
        return PIL_Image.open(f).convert("RGB")


def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    # do_sample=True is required for temperature/top_p to take effect;
    # without it, generate() defaults to greedy decoding and ignores both.
    output = model.generate(**inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512)
    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile, since special tokens in the decoded output
    # shift the character offsets relative to the templated prompt.
    generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return processor.decode(generated_tokens, skip_special_tokens=True)
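
# For reference, the templated prompt for a single user turn looks roughly
# like the following (a sketch; the exact template ships with the checkpoint):
#
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
#   <|image|>Describe this image.<|eot_id|><|start_header_id|>assistant<|end_header_id|>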


def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
    """
    Run the full pipeline: load the model, open the image, and generate text.
    """
    model, processor = load_model_and_processor(model_name, hf_token)
    image = process_image(image_path)
    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
    print("Generated Text: " + result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the Llama 3.2 multimodal model.")
    parser.add_argument("--image_path", type=str, required=True, help="Path to the image file")
    parser.add_argument("--prompt_text", type=str, required=True, help="Prompt text to describe the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
    args = parser.parse_args()
    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
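
# Example invocation (path, prompt, and token here are illustrative):
#   python multi_modal_infer.py \
#       --image_path ./sample_image.jpg \
#       --prompt_text "Describe this image." \
#       --hf_token hf_your_token_here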