llama_text2sql.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. import argparse
  2. import fnmatch
  3. import json
  4. import os
  5. import pdb
  6. import pickle
  7. import re
  8. import sqlite3
  9. from typing import Dict, List, Tuple
  10. import pandas as pd
  11. import sqlparse
  12. import torch
  13. from datasets import Dataset, load_dataset
  14. from langchain_together import ChatTogether
  15. from llama_api_client import LlamaAPIClient
  16. from peft import AutoPeftModelForCausalLM
  17. from tqdm import tqdm
  18. from transformers import (
  19. AutoModelForCausalLM,
  20. AutoTokenizer,
  21. BitsAndBytesConfig,
  22. pipeline,
  23. )
  24. def local_llama(prompt, pipe):
  25. SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
  26. # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
  27. # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
  28. messages = [
  29. {"content": SYSTEM_PROMPT, "role": "system"},
  30. {"role": "user", "content": prompt},
  31. ]
  32. raw_prompt = pipe.tokenizer.apply_chat_template(
  33. messages,
  34. tokenize=False,
  35. add_generation_prompt=True,
  36. )
  37. print(f"local_llama: {raw_prompt=}")
  38. outputs = pipe(
  39. raw_prompt,
  40. max_new_tokens=10240,
  41. do_sample=False,
  42. temperature=0.0,
  43. top_k=50,
  44. top_p=0.1,
  45. eos_token_id=pipe.tokenizer.eos_token_id,
  46. pad_token_id=pipe.tokenizer.pad_token_id,
  47. )
  48. answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
  49. pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
  50. matches = pattern.findall(answer)
  51. if matches != []:
  52. result = matches[0]
  53. else:
  54. result = answer
  55. print(f"{result=}")
  56. return result
  57. def new_directory(path):
  58. if not os.path.exists(path):
  59. os.makedirs(path)
  60. def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
  61. """
  62. Read an sqlite file, and return the CREATE commands for each of the tables in the database.
  63. """
  64. asdf = "database" if bench_root == "spider" else "databases"
  65. with sqlite3.connect(
  66. f"file:{bench_root}/{asdf}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
  67. ) as conn:
  68. # conn.text_factory = bytes
  69. cursor = conn.cursor()
  70. cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
  71. tables = cursor.fetchall()
  72. schemas = {}
  73. for table in tables:
  74. cursor.execute(
  75. "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
  76. table[0]
  77. )
  78. )
  79. schemas[table[0]] = cursor.fetchone()[0]
  80. return schemas
  81. def nice_look_table(column_names: list, values: list):
  82. rows = []
  83. # Determine the maximum width of each column
  84. widths = [
  85. max(len(str(value[i])) for value in values + [column_names])
  86. for i in range(len(column_names))
  87. ]
  88. # Print the column names
  89. header = "".join(
  90. f"{column.rjust(width)} " for column, width in zip(column_names, widths)
  91. )
  92. # print(header)
  93. # Print the values
  94. for value in values:
  95. row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
  96. rows.append(row)
  97. rows = "\n".join(rows)
  98. final_output = header + "\n" + rows
  99. return final_output
  100. def generate_schema_prompt(db_path, num_rows=None):
  101. # extract create ddls
  102. """
  103. :param root_place:
  104. :param db_name:
  105. :return:
  106. """
  107. full_schema_prompt_list = []
  108. conn = sqlite3.connect(db_path)
  109. # Create a cursor object
  110. cursor = conn.cursor()
  111. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  112. tables = cursor.fetchall()
  113. schemas = {}
  114. for table in tables:
  115. if table == "sqlite_sequence":
  116. continue
  117. cursor.execute(
  118. "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
  119. table[0]
  120. )
  121. )
  122. create_prompt = cursor.fetchone()[0]
  123. schemas[table[0]] = create_prompt
  124. if num_rows:
  125. cur_table = table[0]
  126. if cur_table in ["order", "by", "group"]:
  127. cur_table = "`{}`".format(cur_table)
  128. cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
  129. column_names = [description[0] for description in cursor.description]
  130. values = cursor.fetchall()
  131. rows_prompt = nice_look_table(column_names=column_names, values=values)
  132. verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
  133. num_rows, cur_table, num_rows, rows_prompt
  134. )
  135. schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
  136. for k, v in schemas.items():
  137. full_schema_prompt_list.append(v)
  138. schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
  139. return schema_prompt
  140. def generate_comment_prompt(question, knowledge=None):
  141. knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
  142. question_prompt = "-- Question: {}".format(question)
  143. result_prompt = knowledge_prompt + "\n\n" + question_prompt
  144. return result_prompt
  145. def generate_combined_prompts_one(db_path, question, knowledge=None):
  146. schema_prompt = generate_schema_prompt(db_path, num_rows=None)
  147. comment_prompt = generate_comment_prompt(question, knowledge)
  148. combined_prompts = schema_prompt + "\n\n" + comment_prompt
  149. return combined_prompts
  150. def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
  151. SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
  152. try:
  153. if model.startswith("meta-llama/"):
  154. llm = ChatTogether(
  155. model=model,
  156. temperature=0,
  157. )
  158. answer = llm.invoke(SYSTEM_PROMPT + "\n\n" + prompt).content
  159. else:
  160. client = LlamaAPIClient()
  161. messages = [
  162. {"content": SYSTEM_PROMPT, "role": "system"},
  163. {"role": "user", "content": prompt},
  164. ]
  165. response = client.chat.completions.create(
  166. model=model,
  167. messages=messages,
  168. temperature=0,
  169. )
  170. answer = response.completion_message.content.text
  171. pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
  172. matches = pattern.findall(answer)
  173. if matches != []:
  174. result = matches[0]
  175. else:
  176. result = answer
  177. print(result)
  178. except Exception as e:
  179. result = "error:{}".format(e)
  180. print(f"{result=}")
  181. return result
  182. def huggingface_finetuned(api_key, model):
  183. if api_key == "finetuned":
  184. model_id = model
  185. model = AutoPeftModelForCausalLM.from_pretrained(
  186. model_id, device_map="auto", torch_dtype=torch.float16
  187. )
  188. tokenizer = AutoTokenizer.from_pretrained(model_id)
  189. tokenizer.padding_side = "right" # to prevent warnings
  190. if tokenizer.pad_token is None:
  191. tokenizer.add_special_tokens({"pad_token": "[PAD]"})
  192. model.resize_token_embeddings(len(tokenizer))
  193. elif api_key == "huggingface":
  194. model_id = model
  195. bnb_config = BitsAndBytesConfig(
  196. load_in_4bit=True,
  197. bnb_4bit_use_double_quant=True,
  198. bnb_4bit_quant_type="nf4",
  199. bnb_4bit_compute_dtype=torch.bfloat16,
  200. )
  201. model = AutoModelForCausalLM.from_pretrained(
  202. model_id,
  203. device_map="auto",
  204. torch_dtype=torch.bfloat16,
  205. quantization_config=bnb_config,
  206. )
  207. tokenizer = AutoTokenizer.from_pretrained(model_id)
  208. pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
  209. return pipe
  210. def collect_response_from_llama(
  211. db_path_list, question_list, api_key, model, knowledge_list=None
  212. ):
  213. """
  214. :param db_path: str
  215. :param question_list: []
  216. :return: dict of responses
  217. """
  218. responses_dict = {}
  219. response_list = []
  220. if api_key in ["huggingface", "finetuned"]:
  221. pipe = huggingface_finetuned(api_key=api_key, model=model)
  222. for i, question in tqdm(enumerate(question_list)):
  223. print(
  224. "--------------------- processing question #{}---------------------".format(
  225. i + 1
  226. )
  227. )
  228. print("the question is: {}".format(question))
  229. if knowledge_list:
  230. cur_prompt = generate_combined_prompts_one(
  231. db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
  232. )
  233. else:
  234. cur_prompt = generate_combined_prompts_one(
  235. db_path=db_path_list[i], question=question
  236. )
  237. if api_key in ["huggingface", "finetuned"]:
  238. plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
  239. else:
  240. plain_result = cloud_llama(
  241. api_key=api_key,
  242. model=model,
  243. prompt=cur_prompt,
  244. max_tokens=4096,
  245. temperature=0,
  246. stop=["--", "\n\n", ";", "#"],
  247. )
  248. if type(plain_result) == str:
  249. sql = plain_result
  250. else:
  251. sql = "SELECT" + plain_result["choices"][0]["text"]
  252. # responses_dict[i] = sql
  253. db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
  254. sql = (
  255. sql + "\t----- bird -----\t" + db_id
  256. ) # to avoid unpredicted \t appearing in codex results
  257. response_list.append(sql)
  258. return response_list
  259. def question_package(data_json, knowledge=False):
  260. question_list = []
  261. for data in data_json:
  262. question_list.append(data["question"])
  263. return question_list
  264. def knowledge_package(data_json, knowledge=False):
  265. knowledge_list = []
  266. for data in data_json:
  267. knowledge_list.append(data["evidence"])
  268. return knowledge_list
  269. def decouple_question_schema(datasets, db_root_path):
  270. question_list = []
  271. db_path_list = []
  272. knowledge_list = []
  273. for i, data in enumerate(datasets):
  274. question_list.append(data["question"])
  275. cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
  276. db_path_list.append(cur_db_path)
  277. knowledge_list.append(data["evidence"])
  278. return question_list, db_path_list, knowledge_list
  279. def generate_sql_file(sql_lst, output_path=None):
  280. result = {}
  281. for i, sql in enumerate(sql_lst):
  282. result[i] = sql
  283. if output_path:
  284. directory_path = os.path.dirname(output_path)
  285. new_directory(directory_path)
  286. json.dump(result, open(output_path, "w"), indent=4)
  287. return result
  288. if __name__ == "__main__":
  289. args_parser = argparse.ArgumentParser()
  290. args_parser.add_argument("--eval_path", type=str, default="")
  291. args_parser.add_argument("--mode", type=str, default="dev")
  292. args_parser.add_argument("--test_path", type=str, default="")
  293. args_parser.add_argument("--use_knowledge", type=str, default="True")
  294. args_parser.add_argument("--db_root_path", type=str, default="")
  295. args_parser.add_argument("--api_key", type=str, required=True)
  296. args_parser.add_argument("--model", type=str, required=True)
  297. args_parser.add_argument("--data_output_path", type=str)
  298. args = args_parser.parse_args()
  299. if not args.api_key in ["huggingface", "finetuned"]:
  300. if args.model.startswith("meta-llama/"): # Llama model on together
  301. os.environ["TOGETHER_API_KEY"] = args.api_key
  302. llm = ChatTogether(
  303. model=args.model,
  304. temperature=0,
  305. )
  306. try:
  307. response = llm.invoke("125*125 is?").content
  308. print(f"{response=}")
  309. except Exception as exception:
  310. print(f"{exception=}")
  311. exit(1)
  312. else: # Llama model on Llama API
  313. os.environ["LLAMA_API_KEY"] = args.api_key
  314. try:
  315. client = LlamaAPIClient()
  316. response = client.chat.completions.create(
  317. model=args.model,
  318. messages=[{"role": "user", "content": "125*125 is?"}],
  319. temperature=0,
  320. )
  321. answer = response.completion_message.content.text
  322. print(f"{answer=}")
  323. except Exception as exception:
  324. print(f"{exception=}")
  325. exit(1)
  326. eval_data = json.load(open(args.eval_path, "r"))
  327. # '''for debug'''
  328. # eval_data = eval_data[:3]
  329. # '''for debug'''
  330. question_list, db_path_list, knowledge_list = decouple_question_schema(
  331. datasets=eval_data, db_root_path=args.db_root_path
  332. )
  333. assert len(question_list) == len(db_path_list) == len(knowledge_list)
  334. if args.use_knowledge == "True":
  335. responses = collect_response_from_llama(
  336. db_path_list=db_path_list,
  337. question_list=question_list,
  338. api_key=args.api_key,
  339. model=args.model,
  340. knowledge_list=knowledge_list,
  341. )
  342. else:
  343. responses = collect_response_from_llama(
  344. db_path_list=db_path_list,
  345. question_list=question_list,
  346. api_key=args.api_key,
  347. model=args.model,
  348. knowledge_list=None,
  349. )
  350. output_name = args.data_output_path + "predict_" + args.mode + ".json"
  351. generate_sql_file(sql_lst=responses, output_path=output_name)
  352. print("successfully collect results from {}".format(args.model))