瀏覽代碼

passing input_ids as peft doesn't pass position args to base_model

Hamid Shojanazeri 1 年之前
父節點
當前提交
0e9d1dfa78
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      inference/chat_completion.py

+ 1 - 1
inference/chat_completion.py

@@ -107,7 +107,7 @@ def main(
             tokens= tokens.unsqueeze(0)
             tokens= tokens.unsqueeze(0)
             tokens= tokens.to("cuda:0")
             tokens= tokens.to("cuda:0")
             outputs = model.generate(
             outputs = model.generate(
-                tokens,
+                input_ids=tokens,
                 max_new_tokens=max_new_tokens,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 do_sample=do_sample,
                 top_p=top_p,
                 top_p=top_p,