# llama_text2sql_vllm.py

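# Example invocation (illustrative paths; the flags match the argparse setup
# at the bottom of this file):
#
#   python llama_text2sql_vllm.py \
#       --eval_path data/dev.json \
#       --db_root_path data/dev_databases/ \
#       --data_output_path output/ \
#       --mode dev \
#       --use_knowledge True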
import argparse
import json
import os
import re
import sqlite3
from typing import Dict

from tqdm import tqdm
from vllm import LLM

DEFAULT_MAX_TOKENS = 10240

SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

# UNCOMMENT TO USE THE FINE-TUNED MODEL WITH THE REASONING DATASET
# SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

def inference(llm, sampling_params, user_prompt):
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    print(f"{messages=}")
    response = llm.chat(messages, sampling_params, use_tqdm=False)
    print(f"{response=}")
    response_text = response[0].outputs[0].text
    # Extract the SQL from a ```sql ... ``` fenced block if present;
    # otherwise fall back to the raw response text.
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(response_text)
    result = matches[0] if matches else response_text
    print(f"{result=}")
    return result
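
# The model is expected to wrap its answer in a ```sql fenced block, e.g.
# (hypothetical response text):
#
#   ```sql
#   SELECT name FROM schools WHERE charter = 1
#   ```
#
# in which case inference() returns only the statement between the fences;
# otherwise the raw response text is returned unchanged.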

def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read a SQLite file and return the CREATE statement for each table in the database.
    """
    subdir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{subdir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas

def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Format the header of column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Format the value rows
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
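
# For example (hypothetical data), nice_look_table(["id", "name"], [[1, "Acme"]])
# produces right-justified, space-padded columns:
#
#   id name
#    1 Acme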

def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE DDL for every table in the database and, when num_rows
    is given, append num_rows example rows per table.

    :param db_path: path to the SQLite database file
    :param num_rows: number of example rows to include per table (None to skip)
    :return: the "-- DB Schema:" portion of the prompt
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            if cur_table in ["order", "by", "group"]:
                # Quote table names that collide with SQL keywords
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()
    for v in schemas.values():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt

def generate_comment_prompt(question, knowledge=None):
    question_prompt = "-- Question: {}".format(question)
    if knowledge is None:
        # No external knowledge available; don't emit "External Knowledge: None"
        return question_prompt
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    return knowledge_prompt + "\n\n" + question_prompt

def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
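
# Putting the pieces together, the prompt passed to inference() looks like
# (hypothetical schema, knowledge, and question):
#
#   -- DB Schema: CREATE TABLE schools (name TEXT, charter INTEGER, ...)
#
#   -- External Knowledge: Charter schools refers to charter = 1
#
#   -- Question: How many charter schools are there?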

def collect_response_from_llama(
    llm, sampling_params, db_path_list, question_list, knowledge_list=None
):
    response_list = []
    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
        print(
            "--------------------- processing question #{} ---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        plain_result = inference(llm, sampling_params, cur_prompt)
        if isinstance(plain_result, str):
            sql = plain_result
        else:
            # Fallback for completion-style responses returned as a dict
            sql = "SELECT" + plain_result["choices"][0]["text"]
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        # Append the db_id with a sentinel separator, to avoid unpredicted \t
        # appearing in the model results
        sql = sql + "\t----- bird -----\t" + db_id
        response_list.append(sql)
    return response_list

def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list

def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for data in datasets:
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list
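
# Each entry of the BIRD-style eval JSON is expected to provide at least the
# keys used above (values are illustrative):
#
#   {"question": "How many charter schools are there?",
#    "db_id": "california_schools",
#    "evidence": "Charter schools refers to charter = 1"}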

def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        if directory_path:
            new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result
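
# The written predict_<mode>.json maps question indices to the predicted SQL
# plus the "----- bird -----" db_id suffix added above, e.g. (hypothetical SQL):
#
#   {
#       "0": "SELECT COUNT(*) FROM schools WHERE charter = 1\t----- bird -----\tcalifornia_schools"
#   }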

if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument(
        "--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct"
    )
    args_parser.add_argument("--data_output_path", type=str)
    args_parser.add_argument("--max_tokens", type=int, default=DEFAULT_MAX_TOKENS)
    args_parser.add_argument("--temperature", type=float, default=0.0)
    args_parser.add_argument("--top_k", type=int, default=50)
    args_parser.add_argument("--top_p", type=float, default=0.1)
    args = args_parser.parse_args()

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    # NOTE: download_dir points at a cluster-specific model cache; adjust it
    # for your environment.
    llm = LLM(
        model=args.model,
        download_dir="/opt/hpcaas/.mounts/fs-06ad2f76a5ad0b18f/shared/amiryo/.cache/vllm",
    )
    sampling_params = llm.get_default_sampling_params()
    sampling_params.max_tokens = args.max_tokens
    sampling_params.temperature = args.temperature
    sampling_params.top_p = args.top_p
    sampling_params.top_k = args.top_k

    knowledge = knowledge_list if args.use_knowledge == "True" else None
    responses = collect_response_from_llama(
        llm=llm,
        sampling_params=sampling_params,
        db_path_list=db_path_list,
        question_list=question_list,
        knowledge_list=knowledge,
    )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)
    print("successfully collected results from {}".format(args.model))