llama_text2sql.py 13 KB
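"""Text-to-SQL evaluation driver: builds schema-augmented prompts from SQLite
databases and collects SQL predictions from a local Hugging Face / PEFT
fine-tuned model or a hosted Llama model (Together or the Llama API)."""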

import argparse
import fnmatch
import json
import os
import pdb
import pickle
import re
import sqlite3
from typing import Dict, List, Tuple

import pandas as pd
import sqlparse
import torch
from datasets import Dataset, load_dataset
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# ChatTogether and LlamaAPIClient are used below but were missing from the
# original imports; they are assumed to come from the langchain-together and
# llama-api-client packages.
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient

def local_llama(prompt, pipe):
    """Run the prompt through a local Hugging Face text-generation pipeline."""
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    raw_prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    print(f"local_llama: {raw_prompt=}")
    outputs = pipe(
        raw_prompt,
        max_new_tokens=10240,
        # do_sample=False selects greedy decoding; temperature/top_k/top_p are
        # ignored in that mode and kept only for easy switching to sampling.
        do_sample=False,
        temperature=0.0,
        top_k=50,
        top_p=0.1,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )

    # Strip the echoed prompt so only the newly generated SQL remains.
    generated_answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
    print(f"{generated_answer=}")
    return generated_answer

def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read a SQLite file and return the CREATE statement for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                    table[0]
                )
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas

def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]

    # Format the column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Format the values
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output

def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE DDL for each table in the database at db_path and,
    optionally, include num_rows example rows per table.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        # fetchall() returns tuples, so compare the table name, not the tuple
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)

            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for k, v in schemas.items():
        full_schema_prompt_list.append(v)

    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt

def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt

def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
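
# The combined prompt therefore has the following shape (illustrative sketch;
# the schema, knowledge, and question text here are hypothetical):
#
#   -- DB Schema: CREATE TABLE frpm (...)
#
#   CREATE TABLE schools (...)
#
#   -- External Knowledge: Eligible free rate = `Free Meal Count` / `Enrollment`
#
#   -- Question: What is the highest eligible free rate?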

def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
    """Query a hosted Llama model: Together for meta-llama/* models, the Llama API otherwise."""
    try:
        if model.startswith("meta-llama/"):
            llm = ChatTogether(
                model=model,
                temperature=0,
            )
            answer = llm.invoke(prompt).content
        else:
            client = LlamaAPIClient()
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            answer = response.completion_message.content.text

        # Prefer the SQL inside a ```sql ...``` fence if the model emitted one
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        if matches != []:
            result = matches[0]
        else:
            result = answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result

def huggingface_finetuned(api_key, model):
    if api_key == "finetuned":
        # Load a PEFT (LoRA) fine-tuned checkpoint
        model_id = model
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id, device_map="auto", torch_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        tokenizer.padding_side = "right"  # to prevent warnings
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            model.resize_token_embeddings(len(tokenizer))
    elif api_key == "huggingface":
        # Load the base model with 4-bit NF4 quantization
        model_id = model
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe

def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of database paths, one per question
    :param question_list: list of natural-language questions
    :return: list of SQL responses
    """
    response_list = []
    if api_key in ["huggingface", "finetuned"]:
        pipe = huggingface_finetuned(api_key=api_key, model=model)

    for i, question in tqdm(enumerate(question_list)):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
        else:
            plain_result = cloud_llama(
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=4096,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )

        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        # Tag each prediction with its db_id, recovered from the db path
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list

def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list

def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for i, data in enumerate(datasets):
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list

def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        json.dump(result, open(output_path, "w"), indent=4)
    return result
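
# generate_sql_file writes a JSON object mapping question index to the tagged
# SQL string, e.g. (hypothetical values):
#   {
#       "0": "SELECT COUNT(*) FROM schools\t----- bird -----\tcalifornia_schools"
#   }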

if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        # Smoke-test the cloud credentials before running the full eval set
        if args.model.startswith("meta-llama/"):  # Llama model on Together
            os.environ["TOGETHER_API_KEY"] = args.api_key
            llm = ChatTogether(
                model=args.model,
                temperature=0,
            )
            try:
                response = llm.invoke("125*125 is?").content
                print(f"{response=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)
        else:  # Llama model on Llama API
            os.environ["LLAMA_API_KEY"] = args.api_key
            try:
                client = LlamaAPIClient()
                response = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": "125*125 is?"}],
                    temperature=0,
                )
                answer = response.completion_message.content.text
                print(f"{answer=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)

    eval_data = json.load(open(args.eval_path, "r"))
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)

    print("successfully collected results from {}".format(args.model))