llama_text2sql.py

import argparse
import concurrent.futures
import json
import os
import re
import sqlite3
from typing import Dict

from llama_api_client import LlamaAPIClient
from tqdm import tqdm

MAX_NEW_TOKENS = 10240  # If the API caps max tokens (vs. max new tokens), the final budget is calculated from this
TIMEOUT = 60  # Timeout in seconds for each API call

def local_llama(client, prompt, model):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    # UNCOMMENT TO USE THE FINE-TUNED MODEL WITH THE REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    print(f"local_llama: {model=}")

    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        timeout=TIMEOUT,
        temperature=0,
    )
    answer = chat_response.choices[0].message.content.strip()

    # Prefer the SQL inside a ```sql ... ``` fence; fall back to the raw answer.
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    result = matches[0] if matches else answer
    print(f"{result=}")
    return result
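
# Illustrative example (not executed): given a model answer such as
#   "Here is the query:\n```sql\nSELECT name FROM users;\n```"
# the regex above returns "SELECT name FROM users;\n"; answers without a
# ```sql fence are passed through unchanged.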

def batch_local_llama(client, prompts, model, max_workers=8):
    """
    Process multiple prompts in parallel using the vllm server.

    Args:
        client: OpenAI-compatible client pointed at the vllm server
        prompts: List of prompts to process
        model: Model name
        max_workers: Maximum number of parallel workers

    Returns:
        List of results in the same order as prompts
    """
    SYSTEM_PROMPT = (
        "You are a text to SQL query translator. Using the SQLite DB Schema "
        "and the External Knowledge, translate the following text question "
        "into a SQLite SQL select statement."
    )

    def process_single_prompt(prompt):
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        try:
            chat_response = client.chat.completions.create(
                model=model,
                messages=messages,
                timeout=TIMEOUT,
                temperature=0,
            )
            answer = chat_response.choices[0].message.content.strip()
            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
            matches = pattern.findall(answer)
            return matches[0] if matches else answer
        except Exception as e:
            print(f"Error processing prompt: {e}")
            return f"error:{e}"

    print(
        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
        f"using {max_workers} workers"
    )

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks and map each future back to the index of its prompt
        future_to_index = {
            executor.submit(process_single_prompt, prompt): i
            for i, prompt in enumerate(prompts)
        }
        # Pre-size the results list so answers land at their original positions
        results = [None] * len(prompts)
        # Collect futures as they complete
        for future in tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(prompts),
            desc="Processing prompts",
        ):
            index = future_to_index[future]
            try:
                results[index] = future.result()
            except Exception as e:
                print(f"Error processing prompt at index {index}: {e}")
                results[index] = f"error:{e}"
    return results
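
# Usage sketch (assumes the local vLLM OpenAI-compatible server used elsewhere in
# this file; the model name is a placeholder):
#   from openai import OpenAI
#   client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
#   sqls = batch_local_llama(client, prompts, model="fine-tuned-model", max_workers=8)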

def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read a SQLite file and return the CREATE statement for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas

def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column (header included)
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
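
# Illustrative example: nice_look_table(["id", "name"], [[1, "Alice"], [2, "Bob"]])
# returns (columns right-justified to the widest entry; trailing spaces omitted here):
#   id  name
#    1 Alice
#    2   Bob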

def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE statements (and, optionally, example rows) for every table
    in the database and format them as a single schema prompt.

    :param db_path: path to the SQLite database file
    :param num_rows: if set, include this many example rows per table
    :return: the "-- DB Schema: ..." prompt string
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()

    for v in schemas.values():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt

def generate_comment_prompt(question, knowledge=None):
    question_prompt = "-- Question: {}".format(question)
    # Without external knowledge, avoid emitting "-- External Knowledge: None"
    if knowledge is None:
        return question_prompt
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    return knowledge_prompt + "\n\n" + question_prompt

def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
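
# The combined prompt therefore has this layout (schema, knowledge, and question
# are illustrative):
#   -- DB Schema: CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)
#
#   -- External Knowledge: adults are users with age >= 18
#
#   -- Question: How many adult users are there?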

def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
    # max_tokens, temperature and stop are accepted for interface parity;
    # the Llama API call below uses fixed values.
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    try:
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        final_max_tokens = len(messages) + MAX_NEW_TOKENS
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_completion_tokens=final_max_tokens,
        )
        answer = response.completion_message.content.text
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        result = matches[0] if matches else answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result

def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    Generate one SQL prediction per question, processing questions sequentially.

    :param db_path_list: list of database paths, one per question
    :param question_list: list of questions
    :param api_key: "huggingface"/"finetuned" selects the local vllm server, otherwise the Llama API
    :param model: model name
    :param knowledge_list: optional list of external-knowledge strings
    :return: list of SQL strings, each tagged with its db_id
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    for i, question in enumerate(question_list):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
        else:
            plain_result = cloud_llama(
                client=client,
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)
    return response_list
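
# Each returned entry pairs the predicted SQL with its db_id, e.g. (illustrative):
#   "SELECT COUNT(*) FROM users\t----- bird -----\tmy_database"
# so downstream evaluation can tell which database to run each query against.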

def batch_collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8
):
    """
    Process multiple questions in parallel using the vllm server.

    Args:
        db_path_list: List of database paths
        question_list: List of questions
        api_key: API key ("huggingface"/"finetuned" selects the local vllm server)
        model: Model name
        knowledge_list: List of knowledge strings (optional)
        batch_size: Number of parallel requests

    Returns:
        List of SQL responses
    """
    if api_key in ["huggingface", "finetuned"]:
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    # Generate all prompts first
    prompts = []
    for i, question in enumerate(question_list):
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        prompts.append(cur_prompt)
    print(f"Generated {len(prompts)} prompts for batch processing")

    # Process prompts in parallel
    if api_key in ["huggingface", "finetuned"]:
        results = batch_local_llama(
            client=client, prompts=prompts, model=model, max_workers=batch_size
        )
    else:
        # A batch version of cloud_llama could be added for the cloud API if needed;
        # for now, process sequentially.
        results = []
        for prompt in prompts:
            plain_result = cloud_llama(
                client=client,
                api_key=api_key,
                model=model,
                prompt=prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
            results.append(plain_result)

    # Format results
    response_list = []
    for i, result in enumerate(results):
        if isinstance(result, str):
            sql = result
        else:
            sql = "SELECT" + result["choices"][0]["text"]
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)
    return response_list

def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list

def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for data in datasets:
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list
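
# Each eval entry is expected to provide at least these keys (values illustrative):
#   {"question": "How many adult users are there?",
#    "db_id": "my_database",
#    "evidence": "adults are users with age >= 18"}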

def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        if directory_path:
            new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result
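
# The resulting JSON maps question indices to tagged SQL strings, e.g. (illustrative):
#   {
#       "0": "SELECT COUNT(*) FROM users\t----- bird -----\tmy_database"
#   }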

if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args_parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of parallel requests for batch processing",
    )
    args_parser.add_argument(
        "--use_batch", type=str, default="True", help="Whether to use batch processing"
    )
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        os.environ["LLAMA_API_KEY"] = args.api_key
        try:
            # Smoke-test the Llama API credentials and model before running the benchmark
            client = LlamaAPIClient()
            response = client.chat.completions.create(
                model=args.model,
                messages=[{"role": "user", "content": "125*125 is?"}],
                temperature=0,
            )
            answer = response.completion_message.content.text
            print(f"{answer=}")
        except Exception as exception:
            print(f"{exception=}")
            exit(1)

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    use_batch = args.use_batch.lower() == "true"
    if use_batch:
        print(f"Using batch processing with batch_size={args.batch_size}")
        if args.use_knowledge == "True":
            responses = batch_collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=knowledge_list,
                batch_size=args.batch_size,
            )
        else:
            responses = batch_collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=None,
                batch_size=args.batch_size,
            )
    else:
        print("Using sequential processing")
        if args.use_knowledge == "True":
            responses = collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=knowledge_list,
            )
        else:
            responses = collect_response_from_llama(
                db_path_list=db_path_list,
                question_list=question_list,
                api_key=args.api_key,
                model=args.model,
                knowledge_list=None,
            )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)
    print("successfully collected results from {}".format(args.model))