llama_text2sql.py

import argparse
import json
import os
import re
import sqlite3
from typing import Dict

from llama_api_client import LlamaAPIClient

MAX_NEW_TOKENS = 10240  # If the API takes max tokens (vs. max new tokens), we calculate it
TIMEOUT = 60  # Timeout in seconds for each API call


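# Ask a locally served model (OpenAI-compatible endpoint) to translate the prompt
# into SQL, returning the contents of a ```sql ...``` fence when one is present.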
def local_llama(client, prompt, model):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    # UNCOMMENT TO USE THE FINE-TUNED MODEL WITH THE REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    print(f"local_llama: {model=}")

    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        timeout=TIMEOUT,
    )
    answer = chat_response.choices[0].message.content.strip()

    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    if matches:
        result = matches[0]
    else:
        result = answer
    print(f"{result=}")
    return result


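# Create the directory at `path` if it does not already exist.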
def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read an SQLite file and return the CREATE statement for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas


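# Render column names plus sample rows as a fixed-width text table for the prompt.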
def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE DDL for every table in the database at db_path and,
    if num_rows is given, append that many example rows per table.

    :param db_path: path to the .sqlite file
    :param num_rows: number of example rows to include per table (None to skip)
    :return: the "-- DB Schema:" portion of the prompt
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Backtick-quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for v in schemas.values():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


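# Build the "-- External Knowledge:" and "-- Question:" lines of the prompt.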
def generate_comment_prompt(question, knowledge=None):
    # Skip the knowledge line when there is none, rather than emitting
    # "-- External Knowledge: None" into the prompt.
    knowledge_prompt = (
        "-- External Knowledge: {}".format(knowledge) if knowledge else ""
    )
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts


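# With the pieces above, the combined prompt handed to the model looks roughly like:
#
#   -- DB Schema: CREATE TABLE ...
#
#   -- External Knowledge: <evidence, if any>
#
#   -- Question: <the natural-language question>

# Query the Llama API and extract the SQL from the reply; on any failure the
# returned string is prefixed with "error:" so it is visible downstream.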
def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    try:
        messages = [
            {"content": SYSTEM_PROMPT, "role": "system"},
            {"role": "user", "content": prompt},
        ]
        # Rough total-token budget: message count plus the new-token allowance.
        final_max_tokens = len(messages) + MAX_NEW_TOKENS
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_completion_tokens=final_max_tokens,
        )
        answer = response.completion_message.content.text
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        if matches:
            result = matches[0]
        else:
            result = answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result


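# Build one prompt per question and collect one SQL string per question,
# routing to the local endpoint or the Llama API based on api_key.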
def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of .sqlite paths, one per question
    :param question_list: list of natural-language questions
    :param knowledge_list: optional list of external-knowledge strings
    :return: list of generated SQL strings
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        # A locally served model behind an OpenAI-compatible endpoint.
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )
    else:
        client = LlamaAPIClient()

    for i, question in enumerate(question_list):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
        else:
            plain_result = cloud_llama(
                client=client,
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list


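# Helpers that extract the question / external-knowledge ("evidence") fields
# from the eval JSON.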
def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list


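# Split each eval example into its question, its .sqlite path under db_root_path,
# and its "evidence" (external knowledge) string.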
def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for data in datasets:
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list


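# Dump the predictions as {question_index: sql} JSON, creating the output
# directory if needed.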
def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        if directory_path:  # guard against an empty dirname for bare filenames
            new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result


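# Example invocation (paths and model name below are illustrative):
#   python llama_text2sql.py --eval_path data/dev.json --mode dev \
#       --db_root_path data/dev_databases/ --use_knowledge True \
#       --api_key $LLAMA_API_KEY --model <model-name> --data_output_path output/
# Use --api_key huggingface (or finetuned) to hit the local endpoint on
# localhost:8000 instead of the Llama API.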
if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        os.environ["LLAMA_API_KEY"] = args.api_key
        # Smoke-test the API key and model with a trivial request before
        # running the full eval set.
        try:
            client = LlamaAPIClient()
            response = client.chat.completions.create(
                model=args.model,
                messages=[{"role": "user", "content": "125*125 is?"}],
                temperature=0,
            )
            answer = response.completion_message.content.text
            print(f"{answer=}")
        except Exception as exception:
            print(f"{exception=}")
            exit(1)

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)
    print("successfully collected results from {}".format(args.model))