@@ -4,6 +4,7 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
+import json
 import os
 import sys
 
@@ -47,14 +48,18 @@ def main(
 
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
+        try:
+            dialogs = json.loads(dialogs)
+        except json.JSONDecodeError:
+            print("Could not parse JSON from stdin. Please provide valid JSON with the user prompts. Exiting.")
+            sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-
-
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
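
Since the new branch feeds stdin straight into `json.loads`, the script expects a JSON array of dialogs, each dialog being a list of role/content messages. The exact schema is not spelled out in this diff, so the payload below is an assumption modeled on the usual chat-message format:

```python
# Hypothetical stdin payload: a JSON list of dialogs, each a list of
# {"role", "content"} messages. The shape is assumed, not confirmed here.
import json

dialogs = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [
        {"role": "system", "content": "Answer in one sentence."},
        {"role": "user", "content": "What is beam search?"},
    ],
]
print(json.dumps(dialogs))  # pipe this output to the script's stdin
```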
@@ -66,14 +71,8 @@ def main(
     model = load_peft_model(model, peft_model)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.add_special_tokens(
-        {
-
-            "pad_token": "<PAD>",
-        }
-    )
 
-    chats = tokenizer.apply_chat_template(dialogs)
+    chats = [tokenizer.apply_chat_template(dialog) for dialog in dialogs]
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
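
On the rewritten `chats` line: `apply_chat_template` formats a single conversation, so it has to be applied per dialog rather than to the whole list. With its default `tokenize=True` it returns token IDs directly, which is what the loop below wraps in a tensor. A minimal sketch, with an illustrative checkpoint id:

```python
from transformers import AutoTokenizer

# Any tokenizer that ships a chat template behaves the same way;
# this checkpoint id is illustrative.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
dialog = [{"role": "user", "content": "Hello!"}]
ids = tokenizer.apply_chat_template(dialog)  # tokenize=True by default -> list[int]
print(len(ids), ids[:8])
```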
@@ -99,12 +98,14 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens = torch.tensor(chat).long()
             tokens = tokens.unsqueeze(0)
             if is_xpu_available():
                 tokens = tokens.to("xpu:0")
             else:
                 tokens = tokens.to("cuda:0")
+            attention_mask = torch.ones_like(tokens)
             outputs = model.generate(
                 input_ids=tokens,
+                attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 top_p=top_p,
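
The all-ones mask is exact for a single unpadded sequence and silences `generate`'s warning about a missing attention mask; building it from `tokens` after the device move keeps both tensors on the same device. To exercise the whole change end to end, a hypothetical invocation (the script path, flag, and model id are illustrative, not taken from this diff):

```python
import json
import subprocess

dialogs = [[{"role": "user", "content": "Write a haiku about GPUs."}]]
# Assumes the fire CLI exposes main()'s model_name parameter as a flag.
subprocess.run(
    ["python", "chat_completion.py", "--model_name", "meta-llama/Llama-2-7b-chat-hf"],
    input=json.dumps(dialogs),
    text=True,
    check=True,
)
```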