llama_text2sql.py

import argparse
import concurrent.futures
import json
import os
import re
import sqlite3
from typing import Dict

from llama_api_client import LlamaAPIClient
from tqdm import tqdm

MAX_NEW_TOKENS = 10240  # If the API caps total max tokens (vs. max new tokens), we derive the limit from this
TIMEOUT = 60  # Timeout in seconds for each API call


def local_llama(client, api_key, prompts, model, max_workers=8):
    """
    Process multiple prompts in parallel using the vLLM server.

    Args:
        client: OpenAI client pointed at the local vLLM server
        api_key: API key; "huggingface" selects the translate-only system prompt
        prompts: List of prompts to process
        model: Model name
        max_workers: Maximum number of parallel workers

    Returns:
        List of results in the same order as prompts
    """
    SYSTEM_PROMPT = (
        (
            "You are a text to SQL query translator. Using the SQLite DB Schema "
            "and the External Knowledge, translate the following text question "
            "into a SQLite SQL select statement."
        )
        if api_key == "huggingface"
        else (
            "You are a text to SQL query translator. Using the SQLite DB Schema "
            "and the External Knowledge, generate the step-by-step reasoning and "
            "then the final SQLite SQL select statement from the text question."
        )
    )

    def process_single_prompt(prompt):
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=messages,
                timeout=TIMEOUT,
                temperature=0,
            )
            answer = chat_response.choices[0].message.content.strip()
            # Extract the SQL from a ```sql ... ``` fence if present; otherwise
            # fall back to the raw answer
            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
            matches = pattern.findall(answer)
            if not matches:
                result = answer
            else:
                result = matches[0]
            return result
        except Exception as e:
            print(f"Error processing prompt: {e}")
            return f"error:{e}"

    print(
        f"local_llama: Processing {len(prompts)} prompts with {model=} "
        f"using {max_workers} workers"
    )
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks and create a map of futures to their indices
        future_to_index = {
            executor.submit(process_single_prompt, prompt): i
            for i, prompt in enumerate(prompts)
        }
        # Initialize results list with None values
        results = [None] * len(prompts)
        # Process completed futures as they complete
        for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(prompts),
            desc="Processing prompts",
        ):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as e:
                print(f"Error processing prompt at index {index}: {e}")
                results[index] = f"error:{e}"
    return results
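
# Illustrative usage sketch for local_llama (commented out; the model name below
# is a placeholder, while the base URL matches the vLLM default used later in
# this file):
#
#   from openai import OpenAI
#   client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
#   sqls = local_llama(client, "huggingface", prompts, model="my-finetuned-model")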


def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read an SQLite file and return the CREATE statement for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            # Parameterized query avoids quoting issues in table names
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas
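
# For a hypothetical benchmark root and a database containing a single `users`
# table, get_db_schemas would return something like:
#
#   {"users": "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)"}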


def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column (header included)
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Right-align the header and every row to the computed widths
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
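
# Example of the right-aligned layout nice_look_table produces, with hypothetical
# columns and values:
#
#   nice_look_table(["id", "name"], [(1, "Alice"), (2, "Bob")])
#
#   id  name
#    1 Alice
#    2   Bob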


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDL for every table in the database and,
    if num_rows is given, append num_rows example rows per table as a comment.

    :param db_path: path to the SQLite database file
    :param num_rows: number of example rows to include per table (optional)
    :return: the "-- DB Schema: ..." prompt string
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        # fetchall() returns 1-tuples, so compare against the table name itself
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
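
# The combined prompt handed to the model therefore has this overall shape
# (schema text abridged):
#
#   -- DB Schema: CREATE TABLE ...
#
#   -- External Knowledge: <evidence string, or None>
#
#   -- Question: <natural-language question>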


def cloud_llama(client, api_key, model, prompts):
    """
    Process multiple prompts sequentially using the cloud API, showing progress with tqdm.

    Args:
        client: LlamaAPIClient
        api_key: API key
        model: Model name
        prompts: List of prompts to process (or a single prompt as a string)

    Returns:
        List of results if prompts is a list, or a single result if prompts is a string
    """
    SYSTEM_PROMPT = (
        "You are a text to SQL query translator. Using the SQLite DB Schema and "
        "the External Knowledge, translate the following text question into a "
        "SQLite SQL select statement."
    )
    # Handle the case where a single prompt is passed
    single_prompt = False
    if isinstance(prompts, str):
        prompts = [prompts]
        single_prompt = True
    results = []
    # Process each prompt sequentially with a tqdm progress bar
    for prompt in tqdm(prompts, desc="Processing prompts", unit="prompt"):
        try:
            messages = [
                {"content": SYSTEM_PROMPT, "role": "system"},
                {"role": "user", "content": prompt},
            ]
            # Budget the full MAX_NEW_TOKENS plus a small allowance for the messages
            final_max_tokens = len(messages) + MAX_NEW_TOKENS
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
                max_completion_tokens=final_max_tokens,
            )
            answer = response.completion_message.content.text
            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
            matches = pattern.findall(answer)
            if matches:
                result = matches[0]
            else:
                result = answer
        except Exception as e:
            result = "error:{}".format(e)
            print(f"{result=}")
        results.append(result)
    # Return a single result if the input was a single prompt
    if single_prompt:
        return results[0]
    return results


def batch_collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8
):
    """
    Process multiple questions, in parallel against the vLLM server or
    sequentially against the cloud API.

    Args:
        db_path_list: List of database paths
        question_list: List of questions
        api_key: API key
        model: Model name
        knowledge_list: List of knowledge strings (optional)
        batch_size: Number of parallel requests

    Returns:
        List of SQL responses
    """
    if api_key in ["huggingface", "finetuned"]:
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    # Generate all prompts first
    prompts = []
    for i, question in enumerate(question_list):
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        prompts.append(cur_prompt)
    print(f"Generated {len(prompts)} prompts for batch processing")

    # Process prompts in parallel
    if api_key in [
        "huggingface",
        "finetuned",
    ]:  # running vLLM on multiple GPUs gives the best performance
        results = local_llama(
            client=client,
            api_key=api_key,
            prompts=prompts,
            model=model,
            max_workers=batch_size,
        )
    else:
        results = cloud_llama(
            client=client,
            api_key=api_key,
            model=model,
            prompts=prompts,
        )

    # Format results
    response_list = []
    for i, result in enumerate(results):
        if isinstance(result, str):
            sql = result
        else:
            sql = "SELECT" + result["choices"][0]["text"]
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)
    return response_list
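
# Each returned entry pairs the predicted SQL with its db_id via the tab-delimited
# marker, e.g. (SQL and db_id values hypothetical):
#
#   "SELECT COUNT(*) FROM schools\t----- bird -----\tcalifornia_schools"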


def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list


def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for i, data in enumerate(datasets):
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list
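
# decouple_question_schema reads BIRD-style eval entries; a minimal hypothetical
# example of the fields it uses:
#
#   {"question": "How many schools are there?",
#    "db_id": "california_schools",
#    "evidence": "..."}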


def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        if directory_path:  # guard against a bare filename with no directory part
            new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result
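
# The predictions file written above is an index-keyed JSON object, e.g.
# (values hypothetical):
#
#   {"0": "SELECT ...\t----- bird -----\t<db_id>",
#    "1": "SELECT ...\t----- bird -----\t<db_id>"}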


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args_parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of parallel requests for batch processing",
    )
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        os.environ["LLAMA_API_KEY"] = args.api_key
        # Smoke-test the cloud API key and model before running the full benchmark
        try:
            client = LlamaAPIClient()
            response = client.chat.completions.create(
                model=args.model,
                messages=[{"role": "user", "content": "125*125 is?"}],
                temperature=0,
            )
            answer = response.completion_message.content.text
        except Exception as exception:
            print(f"{exception=}")
            exit(1)

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    print(f"Using batch processing with batch_size={args.batch_size}")
    if args.use_knowledge == "True":
        responses = batch_collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
            batch_size=args.batch_size,
        )
    else:
        responses = batch_collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
            batch_size=args.batch_size,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)
    print("Successfully collected results from {}".format(args.model))
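
# Example invocation (paths, key, and model name are placeholders):
#
#   python llama_text2sql.py \
#       --eval_path dev/dev.json \
#       --db_root_path dev/dev_databases/ \
#       --api_key <LLAMA_API_KEY or "huggingface"/"finetuned"> \
#       --model <model-name> \
#       --data_output_path output/ \
#       --use_knowledge True \
#       --batch_size 8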