|
@@ -4,7 +4,11 @@ import argparse
|
|
|
from PIL import Image as PIL_Image
|
|
|
import torch
|
|
|
from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
+from accelerate import Accelerator
|
|
|
|
|
|
+accelerator = Accelerator()
|
|
|
+
|
|
|
+device = accelerator.device
|
|
|
|
|
|
# Constants
|
|
|
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
|
@@ -14,9 +18,11 @@ def load_model_and_processor(model_name: str, hf_token: str):
|
|
|
"""
|
|
|
Load the model and processor based on the 11B or 90B model.
|
|
|
"""
|
|
|
- model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
|
|
|
- model = model.bfloat16().cuda()
|
|
|
- processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
|
|
|
+ model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
|
|
|
+ token=hf_token)
|
|
|
+ processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)
|
|
|
+
|
|
|
+ model, processor=accelerator.prepare(model, processor)
|
|
|
return model, processor
|
|
|
|
|
|
|
|
@@ -39,7 +45,7 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
|
|
|
{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
|
|
|
]
|
|
|
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
|
|
- inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
|
|
+ inputs = processor(image, prompt, return_tensors="pt").to(device)
|
|
|
output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
|
|
|
return processor.decode(output[0])[len(prompt):]
|
|
|
|
|
@@ -64,4 +70,4 @@ if __name__ == "__main__":
|
|
|
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
- main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
|
|
|
+ main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
|