
Fix reading in stdin for chat_completion, remove padding as we're feeding single samples

Matthias Reso 9 months ago
parent
commit
b36830fdf6
1 changed file with 10 additions and 9 deletions

+ 10 - 9
recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py

@@ -4,6 +4,7 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
+import json
 import os
 import sys
 
@@ -47,14 +48,18 @@ def main(
 
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
+        try:
+            dialogs = json.loads(dialogs)
+        except json.JSONDecodeError:
+            print("Could not parse JSON from stdin. Please provide the user prompts as valid JSON. Exiting.")
+            sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-
-
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
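
With this change, the stdin path expects the dialogs as JSON: a list of dialogs, each a list of role/content messages. A minimal sketch of the expected payload and how it is consumed, mirroring the patched code (the file name and prompts are hypothetical):

    # dialogs.json (hypothetical) -- a list of dialogs, each a list of messages:
    # [
    #   [{"role": "user", "content": "What is the capital of France?"}],
    #   [{"role": "user", "content": "Name three prime numbers."}]
    # ]
    #
    # Piped into the script, e.g.:
    #   cat dialogs.json | python chat_completion.py --model_name <model>
    import json
    import sys

    raw = "\n".join(sys.stdin.readlines())
    try:
        dialogs = json.loads(raw)  # -> list of message lists
    except json.JSONDecodeError:
        sys.exit("stdin did not contain valid JSON")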
@@ -66,14 +71,8 @@ def main(
         model = load_peft_model(model, peft_model)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.add_special_tokens(
-        {
-
-            "pad_token": "<PAD>",
-        }
-    )
 
-    chats = tokenizer.apply_chat_template(dialogs)
+    chats = [tokenizer.apply_chat_template(dialog) for dialog in dialogs]
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
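
With padding removed, each dialog is templated individually. Note that `apply_chat_template` tokenizes by default, so each entry in `chats` is already a list of token ids. A minimal sketch, assuming a Llama 2 chat model id for illustration:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # assumed model id
    dialog = [{"role": "user", "content": "Hello!"}]
    ids = tokenizer.apply_chat_template(dialog)  # tokenize=True by default -> list[int]
    print(len(ids), ids[:8])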
@@ -99,12 +98,14 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
            if is_xpu_available():
                tokens= tokens.to("xpu:0")
            else:
                tokens= tokens.to("cuda:0")
+            attention_mask = torch.ones_like(tokens)
             outputs = model.generate(
                 input_ids=tokens,
+                attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 top_p=top_p,
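
Because each sample is generated on its own with no padding, an all-ones attention mask covers the sequence exactly, which is why the `<PAD>` special token registration could be dropped. A self-contained sketch of the resulting generate call (the model id is an assumption; the mask is created after the tensor is moved so both live on the same device):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed model id for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    chat = tokenizer.apply_chat_template([{"role": "user", "content": "Hi there"}])
    tokens = torch.tensor(chat).long().unsqueeze(0).to(model.device)  # shape (1, seq_len)
    # Every position is a real token, so a mask of ones is exact for this
    # single unpadded sample.
    attention_mask = torch.ones_like(tokens)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=tokens,
            attention_mask=attention_mask,
            max_new_tokens=32,
        )
    print(tokenizer.decode(outputs[0]))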