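"""
Text-to-SQL prediction script for BIRD-style benchmarks.

For each question in the eval JSON, the script builds a prompt from the SQLite
database schema plus optional "External Knowledge" (the dataset's evidence
field), asks a Llama model to translate the question into a SQLite SELECT
statement, and writes the predictions to <data_output_path>predict_<mode>.json.
When --api_key is "huggingface" or "finetuned", requests go to a local
OpenAI-compatible server (e.g. vLLM at http://localhost:8000/v1); otherwise the
Llama API client is used.
"""
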
import argparse
import concurrent.futures
import json
import os
import re
import sqlite3
from typing import Dict

from llama_api_client import LlamaAPIClient
from tqdm import tqdm

MAX_NEW_TOKENS = 10240  # If the API has a max-tokens limit (vs. max new tokens), we calculate it
TIMEOUT = 60  # Timeout in seconds for each API call


def local_llama(client, prompt, model):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    # UNCOMMENT TO USE THE FINE-TUNED MODEL TRAINED ON THE REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    print(f"local_llama: {model=}")

    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        timeout=TIMEOUT,
        temperature=0,
    )
    answer = chat_response.choices[0].message.content.strip()

    # Prefer the SQL inside a ```sql ... ``` fenced block; fall back to the raw answer.
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    if not matches:
        result = answer
    else:
        result = matches[0]
    print(f"{result=}")
    return result
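
# Illustrative usage of local_llama against the local OpenAI-compatible server
# configured below in collect_response_from_llama (the prompt and model name
# here are placeholders, not values defined elsewhere in this script):
#
#   from openai import OpenAI
#   client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
#   sql = local_llama(
#       client,
#       "-- DB Schema: ...\n\n-- External Knowledge: ...\n\n-- Question: ...",
#       "my-served-model",
#   )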


def batch_local_llama(client, prompts, model, max_workers=8):
    """
    Process multiple prompts in parallel using the vllm server.

    Args:
        client: OpenAI client
        prompts: List of prompts to process
        model: Model name
        max_workers: Maximum number of parallel workers

    Returns:
        List of results in the same order as prompts
    """
    SYSTEM_PROMPT = (
        "You are a text to SQL query translator. Using the SQLite DB Schema "
        "and the External Knowledge, translate the following text question "
        "into a SQLite SQL select statement."
    )

    def process_single_prompt(prompt):
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=messages,
                timeout=TIMEOUT,
                temperature=0,
            )
            answer = chat_response.choices[0].message.content.strip()
            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
            matches = pattern.findall(answer)
            if not matches:
                result = answer
            else:
                result = matches[0]
            return result
        except Exception as e:
            print(f"Error processing prompt: {e}")
            return f"error:{e}"

    print(
        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
        f"using {max_workers} workers"
    )

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks and create a map of futures to their indices
        future_to_index = {
            executor.submit(process_single_prompt, prompt): i
            for i, prompt in enumerate(prompts)
        }

        # Initialize results list with None values
        results = [None] * len(prompts)

        # Process completed futures as they complete
        for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(prompts),
            desc="Processing prompts",
        ):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as e:
                print(f"Error processing prompt at index {index}: {e}")
                results[index] = f"error:{e}"

    return results
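
# Note: thread-based parallelism is sufficient here because each worker is
# blocked on an HTTP round-trip to the server, and results are written back
# into their original positions via future_to_index, preserving prompt order.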


def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read an sqlite file, and return the CREATE commands for each of the tables in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                    table[0]
                )
            )
            schemas[table[0]] = cursor.fetchone()[0]

        return schemas


def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]

    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
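
# Illustrative output for column_names=["id", "name"] and
# values=[[1, "Alice"], [2, "Bob"]] (each column right-justified to its widest entry):
#
#   id  name
#    1 Alice
#    2   Bob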


def generate_schema_prompt(db_path, num_rows=None):
    """
    Build the "-- DB Schema:" section of the prompt from the CREATE TABLE
    statements (DDL) in the SQLite database at db_path. If num_rows is given,
    append up to num_rows example rows per table as a SQL comment.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)

            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for k, v in schemas.items():
        full_schema_prompt_list.append(v)

    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)

    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)

    result_prompt = knowledge_prompt + "\n\n" + question_prompt

    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)

    combined_prompts = schema_prompt + "\n\n" + comment_prompt

    return combined_prompts
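
# The prompt handed to the model therefore has this overall shape:
#
#   -- DB Schema: CREATE TABLE ...  (one DDL per table, blank-line separated)
#
#   -- External Knowledge: <evidence text, or None when not provided>
#
#   -- Question: <natural-language question>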


def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
    """
    Send a single prompt to the Llama API. The max_tokens, temperature, and stop
    arguments are accepted for call-site compatibility but are not forwarded:
    the request hard-codes temperature=0 and derives its completion-token limit
    from MAX_NEW_TOKENS.
    """
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    try:
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        final_max_tokens = len(messages) + MAX_NEW_TOKENS
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_completion_tokens=final_max_tokens,
        )
        answer = response.completion_message.content.text
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        if matches != []:
            result = matches[0]
        else:
            result = answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result


def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of database paths, one per question
    :param question_list: list of questions
    :param api_key: "huggingface"/"finetuned" for a local server, otherwise a Llama API key
    :param model: model name
    :param knowledge_list: optional list of external-knowledge strings
    :return: list of predicted SQL strings
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"

        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    for i, question in enumerate(question_list):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
        else:
            plain_result = cloud_llama(
                client=client,
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        # responses_dict[i] = sql
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list


def batch_collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8
):
    """
    Process multiple questions in parallel using the vllm server.

    Args:
        db_path_list: List of database paths
        question_list: List of questions
        api_key: API key
        model: Model name
        knowledge_list: List of knowledge strings (optional)
        batch_size: Number of parallel requests

    Returns:
        List of SQL responses
    """
    if api_key in ["huggingface", "finetuned"]:
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"

        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    # Generate all prompts first
    prompts = []
    for i, question in enumerate(question_list):
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        prompts.append(cur_prompt)

    print(f"Generated {len(prompts)} prompts for batch processing")

    # Process prompts in parallel
    if api_key in ["huggingface", "finetuned"]:
        results = batch_local_llama(
            client=client, prompts=prompts, model=model, max_workers=batch_size
        )
    else:
        # For cloud API, we could implement a batch version of cloud_llama if needed
        # For now, just process sequentially
        results = []
        for prompt in prompts:
            plain_result = cloud_llama(
                client=client,
                api_key=api_key,
                model=model,
                prompt=prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
            results.append(plain_result)

    # Format results
    response_list = []
    for i, result in enumerate(results):
        if isinstance(result, str):
            sql = result
        else:
            sql = "SELECT" + result["choices"][0]["text"]

        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list


def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])

    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])

    return knowledge_list


def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for i, data in enumerate(datasets):
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])

    return question_list, db_path_list, knowledge_list
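
# Each eval entry is expected to carry at least the keys "question", "db_id",
# and "evidence" (as in the BIRD dev/test JSON files); the database file is
# resolved as <db_root_path><db_id>/<db_id>.sqlite, so --db_root_path should
# end with a trailing "/".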


def generate_sql_file(sql_lst, output_path=None):
    """Write the predicted SQL list to output_path as a JSON object keyed by question index."""
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql

    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)

    return result
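
# Illustrative output file contents (the SQL and db_id values are placeholders):
#
#   {
#       "0": "SELECT ...\t----- bird -----\t<db_id>",
#       "1": "error:<exception message>\t----- bird -----\t<db_id>"
#   }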


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args_parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of parallel requests for batch processing",
    )
    args_parser.add_argument(
        "--use_batch", type=str, default="True", help="Whether to use batch processing"
    )
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        os.environ["LLAMA_API_KEY"] = args.api_key
        # Sanity-check the API key and model with a trivial request before the full run
        try:
            client = LlamaAPIClient()

            response = client.chat.completions.create(
                model=args.model,
                messages=[{"role": "user", "content": "125*125 is?"}],
                temperature=0,
            )
            answer = response.completion_message.content.text
            print(f"{answer=}")
        except Exception as exception:
            print(f"{exception=}")
            exit(1)

    eval_data = json.load(open(args.eval_path, "r"))
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    use_batch = args.use_batch.lower() == "true"

    if use_batch:
        print(f"Using batch processing with batch_size={args.batch_size}")
        if args.use_knowledge == "True":
            responses = batch_collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=knowledge_list,
                batch_size=args.batch_size,
            )
        else:
            responses = batch_collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=None,
                batch_size=args.batch_size,
            )
    else:
        print("Using sequential processing")
        if args.use_knowledge == "True":
            responses = collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=knowledge_list,
            )
        else:
            responses = collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=None,
            )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)

    print("Successfully collected results from {}".format(args.model))
|