llama_text2sql.py

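"""
Generate text-to-SQL predictions for a BIRD-style eval set.

For each question, this script builds a prompt from the SQLite DB schema, the
external knowledge (evidence), and the question, sends it either to a local
Hugging Face pipeline (base or fine-tuned model) or to a hosted Llama model
(Together or the Llama API), and writes the predicted SQL statements to
predict_<mode>.json under --data_output_path.
"""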
import argparse
import json
import os
import re
import sqlite3
from typing import Dict

import torch
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)


def local_llama(prompt, pipe):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    raw_prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    print(f"local_llama: {raw_prompt=}")
    # Greedy decoding; with do_sample=False the sampling knobs
    # (temperature, top_k, top_p) are ignored, so they are omitted here.
    outputs = pipe(
        raw_prompt,
        max_new_tokens=10240,
        do_sample=False,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )
    answer = outputs[0]["generated_text"][len(raw_prompt):].strip()

    # Prefer the SQL inside a ```sql ... ``` fenced block; fall back to the
    # raw completion if the model did not emit one.
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    result = matches[0] if matches else answer
    print(f"{result=}")
    return result


def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read an SQLite file and return the CREATE statement for each table
    in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            # Parameterized query: table names may contain quotes.
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                (table[0],),
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas


def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column (header included).
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Format the header row.
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Format the value rows.
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
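
# For example, nice_look_table(["id", "name"], [(1, "a"), (2, "b")]) returns
# (modulo trailing spaces):
#
#   id name
#    1    a
#    2    b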


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDL for every table in the database at db_path
    and, if num_rows is given, append num_rows example rows per table.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords.
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
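
# The combined prompt handed to the model therefore looks like:
#
#   -- DB Schema: CREATE TABLE ... (one DDL block per table)
#
#   -- External Knowledge: <evidence string, or None>
#
#   -- Question: <natural-language question>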


def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
    # max_tokens, temperature and stop are accepted for call-site symmetry
    # but are not forwarded; both providers are invoked with temperature=0.
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    try:
        if model.startswith("meta-llama/"):
            # Llama model served by Together
            llm = ChatTogether(
                model=model,
                temperature=0,
            )
            answer = llm.invoke(SYSTEM_PROMPT + "\n\n" + prompt).content
        else:
            # Llama model served by the Llama API
            client = LlamaAPIClient()
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ]
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            answer = response.completion_message.content.text

        # Prefer the SQL inside a ```sql ... ``` fenced block.
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        result = matches[0] if matches else answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
    print(f"{result=}")
    return result


def huggingface_finetuned(api_key, model):
    if api_key == "finetuned":
        model_id = model

        # Check if this is a PEFT model by looking for adapter_config.json.
        is_peft_model = os.path.exists(os.path.join(model_id, "adapter_config.json"))

        if is_peft_model:
            # Use AutoPeftModelForCausalLM for PEFT fine-tuned models.
            print(f"Loading PEFT model from {model_id}")
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        else:
            # Use AutoModelForCausalLM for FFT (Full Fine-Tuning) models.
            print(f"Loading FFT model from {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)

            # For FFT models, handle the pad token if it was added during training.
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                model.resize_token_embeddings(len(tokenizer))

        tokenizer.padding_side = "right"  # to prevent warnings
    elif api_key == "huggingface":
        model_id = model
        # Load the base model with 4-bit NF4 quantization to reduce memory use.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe


def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of .sqlite paths, one per question
    :param question_list: list of natural-language questions
    :return: list of predicted SQL strings
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        pipe = huggingface_finetuned(api_key=api_key, model=model)

    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
        else:
            plain_result = cloud_llama(
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=4096,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )

        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        # Tag each prediction with its db_id so the evaluator can recover it;
        # the marker avoids unpredicted \t appearing in model results.
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = sql + "\t----- bird -----\t" + db_id
        response_list.append(sql)

    return response_list


def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list


def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for data in datasets:
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list


def generate_sql_file(sql_lst, output_path=None):
    result = {i: sql for i, sql in enumerate(sql_lst)}
    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        # Use a context manager so the file handle is closed after writing.
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result
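
# The resulting JSON maps question indices to tagged SQL strings, e.g.
#   { "0": "SELECT ...\t----- bird -----\t<db_id>", ... }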


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    # For cloud models, smoke-test the credentials with a trivial request
    # before running the full eval set.
    if args.api_key not in ["huggingface", "finetuned"]:
        if args.model.startswith("meta-llama/"):  # Llama model on Together
            os.environ["TOGETHER_API_KEY"] = args.api_key
            llm = ChatTogether(
                model=args.model,
                temperature=0,
            )
            try:
                response = llm.invoke("125*125 is?").content
                print(f"{response=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)
        else:  # Llama model on the Llama API
            os.environ["LLAMA_API_KEY"] = args.api_key
            try:
                client = LlamaAPIClient()
                response = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": "125*125 is?"}],
                    temperature=0,
                )
                answer = response.completion_message.content.text
                print(f"{answer=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)

    eval_data = json.load(open(args.eval_path, "r"))
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)

    print("successfully collected results from {}".format(args.model))