llama_text2sql.py
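
# Text-to-SQL evaluation driver: for each benchmark example it builds a prompt
# from the SQLite schema, optional external knowledge, and the question,
# generates a SQL query with either a local Hugging Face / fine-tuned Llama
# pipeline or a cloud-hosted model (Together or the Llama API), and writes the
# predictions to a JSON file in the BIRD benchmark's expected format.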

import argparse
import fnmatch
import json
import os
import pdb
import pickle
import re
import sqlite3
from typing import Dict, List, Tuple

import pandas as pd
import sqlparse
import torch
from datasets import Dataset, load_dataset
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient  # client for Llama models served via the Llama API
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)


def local_llama(prompt, pipe):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    raw_prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    print(f"local_llama: {raw_prompt=}")

    outputs = pipe(
        raw_prompt,
        max_new_tokens=10240,
        do_sample=False,  # greedy decoding; the sampling knobs below are ignored when do_sample=False
        temperature=0.0,
        top_k=50,
        top_p=0.1,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )
    # Strip the echoed prompt so only the newly generated SQL remains.
    generated_answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
    print(f"{generated_answer=}")
    return generated_answer


def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read an SQLite file and return the CREATE command for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas


def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column (header row included)
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]

    # Build the header from the column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Build one aligned row per record
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
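
# For example, nice_look_table(["id", "name"], [(1, "Alice"), (2, "Bob")])
# returns right-justified columns:
#   id  name
#    1 Alice
#    2   Bob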


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDLs from the SQLite database at db_path and join
    them into a single schema prompt. If num_rows is given, append that many
    example rows after each table's DDL.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        # fetchall() returns tuples, so compare the table name, not the tuple
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Backtick-quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for k, v in schemas.items():
        full_schema_prompt_list.append(v)

    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    question_prompt = "-- Question: {}".format(question)
    if knowledge is None:
        # No external knowledge: emit only the question line instead of
        # "-- External Knowledge: None"
        return question_prompt
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    return knowledge_prompt + "\n\n" + question_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
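
# Illustrative sketch of the combined prompt these helpers produce for a toy
# one-table schema (schema, knowledge, and question values are made up):
#
#   -- DB Schema: CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)
#
#   -- External Knowledge: "adult" means age >= 18
#
#   -- Question: How many adult users are there?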


def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
    # max_tokens, temperature, and stop are accepted for interface
    # compatibility but not forwarded; both backends run at temperature 0.
    try:
        if model.startswith("meta-llama/"):
            # Llama models hosted on Together
            llm = ChatTogether(
                model=model,
                temperature=0,
            )
            answer = llm.invoke(prompt).content
        else:
            # Llama models hosted on the Llama API
            client = LlamaAPIClient()
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            answer = response.completion_message.content.text

        # Prefer the SQL inside a ```sql ... ``` fence if the model emitted one
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        if matches != []:
            result = matches[0]
        else:
            result = answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result
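
# For example, with answer = "```sql\nSELECT count(*) FROM users;\n```" the
# regex above extracts the fenced SQL body "SELECT count(*) FROM users;",
# while an answer with no ```sql fence is returned unchanged.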


def huggingface_finetuned(api_key, model):
    if api_key == "finetuned":
        model_id = model
        # Load a PEFT (LoRA) fine-tuned checkpoint
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.padding_side = "right"  # to prevent warnings
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            model.resize_token_embeddings(len(tokenizer))
    elif api_key == "huggingface":
        model_id = model
        # 4-bit NF4 quantization so the base model fits in less GPU memory
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe


def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of database file paths, one per question
    :param question_list: list of natural-language questions
    :return: list of generated SQL strings, each tagged with its db_id
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        pipe = huggingface_finetuned(api_key=api_key, model=model)

    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
        else:
            plain_result = cloud_llama(
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=4096,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )

        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list


def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list


def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for i, data in enumerate(datasets):
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list


def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result
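
# The written JSON maps question indices to tagged SQL strings, e.g. (values
# illustrative):
#   {"0": "SELECT count(*) FROM users\t----- bird -----\tcalifornia_schools"}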


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        # Smoke-test the cloud credentials and model before the full eval run
        if args.model.startswith("meta-llama/"):  # Llama model on Together
            os.environ["TOGETHER_API_KEY"] = args.api_key
            llm = ChatTogether(
                model=args.model,
                temperature=0,
            )
            try:
                response = llm.invoke("125*125 is?").content
                print(f"{response=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)
        else:  # Llama model on the Llama API
            os.environ["LLAMA_API_KEY"] = args.api_key
            try:
                client = LlamaAPIClient()
                response = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": "125*125 is?"}],
                    temperature=0,
                )
                answer = response.completion_message.content.text
                print(f"{answer=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)

    print("successfully collected results from {}".format(args.model))
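
# Example invocation (paths and model name are illustrative; trailing slashes
# matter because the script concatenates paths as plain strings):
#
#   python llama_text2sql.py \
#       --eval_path data/dev.json \
#       --db_root_path data/dev_databases/ \
#       --api_key huggingface \
#       --model meta-llama/Llama-3.1-8B-Instruct \
#       --data_output_path output/ \
#       --mode dev \
#       --use_knowledge True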