removed dependency on single GPU using accelerate

himanshushukla12 · 6 months ago
Parent
Current commit
625860d3db
1 changed file with 11 additions and 5 deletions

+ 11 - 5  recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -4,7 +4,11 @@ import argparse
 from PIL import Image as PIL_Image
 import torch
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from accelerate import Accelerator
 
+accelerator = Accelerator()
+
+device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -14,9 +18,11 @@ def load_model_and_processor(model_name: str, hf_token: str):
     """
     Load the model and processor based on the 11B or 90B model.
     """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
-    model = model.bfloat16().cuda()
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
+    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True, device_map=device,
+                                                           token=hf_token)
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
+
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
 
@@ -39,7 +45,7 @@ def generate_text_from_image(model, processor, image, prompt_text: str, temperat
         {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
     prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-    inputs = processor(image, prompt, return_tensors="pt").to(model.device)
+    inputs = processor(image, prompt, return_tensors="pt").to(device)
     output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
     return processor.decode(output[0])[len(prompt):]
 
@@ -64,4 +70,4 @@ if __name__ == "__main__":
     parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
 
     args = parser.parse_args()
-    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
+    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
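
For context, here is a minimal standalone sketch of the device-placement pattern this commit introduces. It is not part of the commit itself; the model id mirrors DEFAULT_MODEL from the script, and "<hf_token>" is a placeholder:

import torch
from accelerate import Accelerator
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Accelerator picks the best available device (cuda:N, mps, or cpu),
# replacing the hard-coded .cuda() call the commit removes.
accelerator = Accelerator()
device = accelerator.device

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # DEFAULT_MODEL in the script
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map=device,   # place the weights on the Accelerator-chosen device
    token="<hf_token>",  # placeholder; supply a real Hugging Face token
)
processor = MllamaProcessor.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    token="<hf_token>",
)
# Wrap model and processor so Accelerate handles placement uniformly.
model, processor = accelerator.prepare(model, processor)

With this change the script runs unmodified on CUDA, Apple Silicon (MPS), or CPU hosts. Judging from the attributes read off args in main(), it would be invoked along the lines of

python multi_modal_infer.py --image_path <image> --prompt_text "<prompt>" --hf_token <token>

(only --hf_token is visible in the hunk above; the other flag names are inferred from main()'s arguments).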