|
@@ -14,6 +14,7 @@ from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
|
|
from safety_utils import get_safety_checker
|
|
from safety_utils import get_safety_checker
|
|
from model_utils import load_model, load_peft_model
|
|
from model_utils import load_model, load_peft_model
|
|
from chat_utils import read_dialogs_from_file, format_tokens
|
|
from chat_utils import read_dialogs_from_file, format_tokens
|
|
|
|
+from accelerate.utils import is_xpu_available
|
|
|
|
|
|
def main(
|
|
def main(
|
|
model_name,
|
|
model_name,
|