|
@@ -15,6 +15,7 @@ def load_model_and_processor(model_name: str, hf_token: str):
|
|
|
Load the model and processor based on the 11B or 90B model.
|
|
|
"""
|
|
|
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
|
|
|
+ model = model.bfloat16().cuda()
|
|
|
processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
|
|
|
return model, processor
|
|
|
|