@@ -5,21 +5,13 @@ import re
 import sqlite3
 from typing import Dict
 
-import torch
-from langchain_together import ChatTogether
 from llama_api_client import LlamaAPIClient
-from peft import AutoPeftModelForCausalLM
-from tqdm import tqdm
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    pipeline,
-)
-
-MAX_NEW_TOKENS=10240 # If API has max tokens (vs max new tokens), we calculate it
-
-def local_llama(prompt, pipe):
+
+MAX_NEW_TOKENS = 10240  # If API has max tokens (vs max new tokens), we calculate it
+TIMEOUT = 60  # Timeout in seconds for each API call
+
+
+def local_llama(client, prompt, model):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
     # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
@@ -28,27 +20,13 @@ def local_llama(prompt, pipe):
         {"content": SYSTEM_PROMPT, "role": "system"},
         {"role": "user", "content": prompt},
     ]
-
-    raw_prompt = pipe.tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-    )
-
-    print(f"local_llama: {raw_prompt=}")
-
-    outputs = pipe(
-        raw_prompt,
-        max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=False,
-        temperature=0.0,
-        top_k=50,
-        top_p=0.1,
-        eos_token_id=pipe.tokenizer.eos_token_id,
-        pad_token_id=pipe.tokenizer.pad_token_id,
+    print(f"local_llama: {model=}")
+    chat_response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        timeout=TIMEOUT,
     )
-
-    answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
+    answer = chat_response.choices[0].message.content.strip()
 
     pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
     matches = pattern.findall(answer)
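
Note: the SQL is recovered from the model reply by the fenced-block regex above. A minimal sketch of the extraction, with a hypothetical `answer` and an assumed fallback to the raw reply when no fence is found:

    import re

    answer = "Here is the query:\n```sql\nSELECT name FROM users WHERE id = 1;\n```"
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    # matches == ['SELECT name FROM users WHERE id = 1;\n']
    sql = matches[0].strip() if matches else answer.strip()
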
@@ -98,12 +76,9 @@ def nice_look_table(column_names: list, values: list):
         for i in range(len(column_names))
     ]
 
-    # Print the column names
     header = "".join(
         f"{column.rjust(width)} " for column, width in zip(column_names, widths)
     )
-    # print(header)
-    # Print the values
     for value in values:
         row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
         rows.append(row)
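
Note: the surviving lines right-justify each cell to its column width. A quick illustration with hypothetical data (`widths` computed the way the context lines suggest):

    column_names = ["id", "name"]
    values = [(1, "alice"), (2, "bob")]
    widths = [
        max(len(str(column_names[i])), max(len(str(row[i])) for row in values))
        for i in range(len(column_names))
    ]
    header = "".join(f"{c.rjust(w)} " for c, w in zip(column_names, widths))
    # header == 'id  name '
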
@@ -176,33 +151,22 @@ def generate_combined_prompts_one(db_path, question, knowledge=None):
     return combined_prompts
 
 
-def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
+def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     try:
-        if model.startswith("meta-llama/"):
-            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
-            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS
-            llm = ChatTogether(
-                model=model,
-                temperature=0,
-                max_tokens=final_max_tokens,
-            )
-            answer = llm.invoke(final_prompt).content
-        else:
-            client = LlamaAPIClient()
-            messages = [
-                {"content": SYSTEM_PROMPT, "role": "system"},
-                {"role": "user", "content": prompt},
-            ]
-            final_max_tokens = len(messages) + MAX_NEW_TOKENS
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-                max_completion_tokens=final_max_tokens,
-            )
-            answer = response.completion_message.content.text
+        messages = [
+            {"content": SYSTEM_PROMPT, "role": "system"},
+            {"role": "user", "content": prompt},
+        ]
+        final_max_tokens = len(messages) + MAX_NEW_TOKENS
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=0,
+            max_completion_tokens=final_max_tokens,
+        )
+        answer = response.completion_message.content.text
 
         pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
         matches = pattern.findall(answer)
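
Note: `len(messages)` counts chat turns (2 here), not tokens, so `final_max_tokens` works out to MAX_NEW_TOKENS + 2, which matches the intent of the comment on MAX_NEW_TOKENS only loosely. The function also now expects the caller to construct the client. A sketch of the new calling convention (the key and model name are placeholders, and `sql_prompt` stands in for a prompt built by generate_combined_prompts_one):

    import os
    from llama_api_client import LlamaAPIClient

    os.environ["LLAMA_API_KEY"] = "<your-llama-api-key>"
    client = LlamaAPIClient()
    sql_prompt = "<schema + external knowledge + question prompt>"
    result = cloud_llama(
        client=client,
        api_key="<your-llama-api-key>",
        model="Llama-3.3-70B-Instruct",  # illustrative model name
        prompt=sql_prompt,
        max_tokens=10240,
        temperature=0,
        stop=["--", "\n\n", ";", "#"],
    )
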
@@ -218,57 +182,6 @@ def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
     return result
 
 
-def huggingface_finetuned(api_key, model):
-    if api_key == "finetuned":
-        model_id = model
-
-        # Check if this is a PEFT model by looking for adapter_config.json
-        import os
-
-        is_peft_model = os.path.exists(os.path.join(model_id, "adapter_config.json"))
-
-        if is_peft_model:
-            # Use AutoPeftModelForCausalLM for PEFT fine-tuned models
-            print(f"Loading PEFT model from {model_id}")
-            model = AutoPeftModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", torch_dtype=torch.float16
-            )
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-        else:
-            # Use AutoModelForCausalLM for FFT (Full Fine-Tuning) models
-            print(f"Loading FFT model from {model_id}")
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", torch_dtype=torch.float16
-            )
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-        # For FFT models, handle pad token if it was added during training
-        if tokenizer.pad_token is None:
-            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-            model.resize_token_embeddings(len(tokenizer))
-
-        tokenizer.padding_side = "right"  # to prevent warnings
-
-    elif api_key == "huggingface":
-        model_id = model
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            quantization_config=bnb_config,  # None
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-    return pipe
-
-
 def collect_response_from_llama(
     db_path_list, question_list, api_key, model, knowledge_list=None
 ):
@@ -280,9 +193,19 @@ def collect_response_from_llama(
     response_list = []
 
     if api_key in ["huggingface", "finetuned"]:
-        pipe = huggingface_finetuned(api_key=api_key, model=model)
+        from openai import OpenAI
+
+        openai_api_key = "EMPTY"
+        openai_api_base = "http://localhost:8000/v1"
 
-    for i, question in tqdm(enumerate(question_list)):
+        client = OpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+    else:
+        client = LlamaAPIClient()
+
+    for i, question in enumerate(question_list):
         print(
             "--------------------- processing question #{}---------------------".format(
                 i + 1
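
Note: the hard-coded `http://localhost:8000/v1` base URL assumes the fine-tuned checkpoint is already being served by an OpenAI-compatible inference server, e.g. `vllm serve <model-path> --port 8000`. A small connectivity check under that assumption:

    from openai import OpenAI

    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
    print([m.id for m in client.models.list().data])  # should list the served model
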
@@ -300,9 +223,10 @@ def collect_response_from_llama(
             )
 
         if api_key in ["huggingface", "finetuned"]:
-            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
+            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
         else:
             plain_result = cloud_llama(
+                client=client,
                 api_key=api_key,
                 model=model,
                 prompt=cur_prompt,
@@ -310,7 +235,7 @@ def collect_response_from_llama(
                 temperature=0,
                 stop=["--", "\n\n", ";", "#"],
             )
-        if type(plain_result) == str:
+        if isinstance(plain_result, str):
             sql = plain_result
         else:
             sql = "SELECT" + plain_result["choices"][0]["text"]
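
Note: both chat paths above return a plain string, while the `else` arm guards against a legacy completions-style dict. The two shapes, with hypothetical values:

    plain_result = "SELECT count(*) FROM users"  # str: used as-is
    plain_result = {"choices": [{"text": " count(*) FROM users"}]}  # dict: legacy shape
    sql = "SELECT" + plain_result["choices"][0]["text"]
    # sql == 'SELECT count(*) FROM users'
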
@@ -379,37 +304,23 @@ if __name__ == "__main__":
     args_parser.add_argument("--data_output_path", type=str)
     args = args_parser.parse_args()
 
-    if args.api_key not in ["huggingface", "finetuned"]:
-        if args.model.startswith("meta-llama/"):  # Llama model on together
-            os.environ["TOGETHER_API_KEY"] = args.api_key
-            llm = ChatTogether(
-                model=args.model,
-                temperature=0,
-            )
-            try:
-                response = llm.invoke("125*125 is?").content
-                print(f"{response=}")
-            except Exception as exception:
-                print(f"{exception=}")
-                exit(1)
-        else:  # Llama model on Llama API
-            os.environ["LLAMA_API_KEY"] = args.api_key
-
-            try:
-                client = LlamaAPIClient()
-
-                response = client.chat.completions.create(
-                    model=args.model,
-                    messages=[{"role": "user", "content": "125*125 is?"}],
-                    temperature=0,
-                )
-                answer = response.completion_message.content.text
-
-                print(f"{answer=}")
-            except Exception as exception:
-                print(f"{exception=}")
-                exit(1)
+    if args.api_key not in ["huggingface", "finetuned"]:
+        os.environ["LLAMA_API_KEY"] = args.api_key
+
+        try:
+            client = LlamaAPIClient()
+
+            response = client.chat.completions.create(
+                model=args.model,
+                messages=[{"role": "user", "content": "125*125 is?"}],
+                temperature=0,
+            )
+            answer = response.completion_message.content.text
+
+            print(f"{answer=}")
+        except Exception as exception:
+            print(f"{exception=}")
+            exit(1)
 
     eval_data = json.load(open(args.eval_path, "r"))
     # '''for debug'''
@@ -422,7 +333,7 @@ if __name__ == "__main__":
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
     if args.use_knowledge == "True":
-        responses = collect_response_from_llama(
+        responses = collect_response_from_llama(  # see also: collect_batch_response_from_llama
             db_path_list=db_path_list,
             question_list=question_list,
             api_key=args.api_key,