# llama_text2sql.py
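# Text-to-SQL evaluation harness for BIRD-style benchmarks: builds prompts from
# SQLite schemas plus external knowledge, collects SQL predictions from a Llama
# model (local Hugging Face pipeline, Together, or the Llama API), and writes
# them to a JSON file for scoring.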

import argparse
import fnmatch
import json
import os
import pdb
import pickle
import re
import sqlite3
from typing import Dict, List, Tuple

import pandas as pd
import sqlparse
import torch
from datasets import Dataset, load_dataset
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

MAX_NEW_TOKENS = 10240  # If the API takes max tokens (vs. max new tokens), we calculate it

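# Generate a completion from a local Hugging Face text-generation pipeline and
# extract the SQL statement from a ```sql ...``` fenced block if one is present.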
def local_llama(prompt, pipe):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    # UNCOMMENT TO USE THE FINE-TUNED MODEL WITH THE REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    raw_prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    print(f"local_llama: {raw_prompt=}")

    # Greedy decoding; temperature/top_k/top_p have no effect when do_sample=False.
    outputs = pipe(
        raw_prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        temperature=0.0,
        top_k=50,
        top_p=0.1,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )
    # Drop the echoed prompt; keep only the newly generated text.
    answer = outputs[0]["generated_text"][len(raw_prompt):].strip()

    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    result = matches[0] if matches else answer
    print(f"{result=}")
    return result

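# Create the directory at path if it does not already exist.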
def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read a SQLite file and return the CREATE statement for each table in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                    table[0]
                )
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas

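# Render column names and rows as a fixed-width text table for use in prompts.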
def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Format the column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # print(header)
    # Format the values
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output

def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDL for every table in the database at db_path,
    optionally followed by num_rows example rows per table.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt

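# Build the "-- External Knowledge:" / "-- Question:" comment lines appended
# below the schema in the prompt.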
def generate_comment_prompt(question, knowledge=None):
    question_prompt = "-- Question: {}".format(question)
    # Skip the knowledge line entirely when no external knowledge is supplied,
    # so the prompt never contains a literal "None".
    if knowledge is None:
        return question_prompt
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    return knowledge_prompt + "\n\n" + question_prompt

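# Assemble the full prompt: DB schema followed by knowledge and question.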
def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts

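# Query a hosted Llama model (Together for "meta-llama/" models, otherwise the
# Llama API) and extract the SQL from the response.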
def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    try:
        if model.startswith("meta-llama/"):
            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
            # Rough character-count upper bound for APIs whose limit covers the
            # prompt plus completion rather than new tokens alone.
            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS
            llm = ChatTogether(
                model=model,
                temperature=0,
                max_tokens=final_max_tokens,
            )
            answer = llm.invoke(final_prompt).content
        else:
            client = LlamaAPIClient()
            messages = [
                {"content": SYSTEM_PROMPT, "role": "system"},
                {"role": "user", "content": prompt},
            ]
            # Same character-based bound as above (the original used
            # len(messages), i.e. 2, which did not account for the prompt).
            final_max_tokens = len(prompt) + MAX_NEW_TOKENS
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
                max_completion_tokens=final_max_tokens,
            )
            answer = response.completion_message.content.text
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        result = matches[0] if matches else answer
    except Exception as e:
        result = "error:{}".format(e)
    print(f"{result=}")
    return result

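# Load the model and tokenizer for local inference: a PEFT or fully fine-tuned
# checkpoint when api_key == "finetuned", or a 4-bit quantized base model when
# api_key == "huggingface". Returns a transformers text-generation pipeline.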
def huggingface_finetuned(api_key, model):
    if api_key == "finetuned":
        model_id = model
        # Check if this is a PEFT model by looking for adapter_config.json
        is_peft_model = os.path.exists(os.path.join(model_id, "adapter_config.json"))
        if is_peft_model:
            # Use AutoPeftModelForCausalLM for PEFT fine-tuned models
            print(f"Loading PEFT model from {model_id}")
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        else:
            # Use AutoModelForCausalLM for FFT (Full Fine-Tuning) models
            print(f"Loading FFT model from {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            # For FFT models, handle the pad token if it was added during training
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                model.resize_token_embeddings(len(tokenizer))
        tokenizer.padding_side = "right"  # to prevent warnings
    elif api_key == "huggingface":
        model_id = model
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe

def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of SQLite database paths, parallel to question_list
    :param question_list: list of natural-language questions
    :param knowledge_list: optional list of external-knowledge strings
    :return: list of predicted SQL strings, each suffixed with its db_id
    """
    response_list = []
    if api_key in ["huggingface", "finetuned"]:
        pipe = huggingface_finetuned(api_key=api_key, model=model)
    for i, question in tqdm(enumerate(question_list)):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))
        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )
        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
        else:
            plain_result = cloud_llama(
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )
        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)
    return response_list

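# Collect the question strings from the loaded eval JSON.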
def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list

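# Collect the external-knowledge ("evidence") strings from the loaded eval JSON.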
def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list

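# Split the eval records into parallel lists of questions, database paths, and
# knowledge strings.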
def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for i, data in enumerate(datasets):
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list

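# Write the predicted SQL statements to output_path as an {index: sql} JSON dict.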
def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql
    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)
    return result

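# CLI entry point: verify API access for cloud models, load the eval JSON,
# generate one SQL prediction per question, and save predict_<mode>.json.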
if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    if args.api_key not in ["huggingface", "finetuned"]:
        if args.model.startswith("meta-llama/"):  # Llama model on Together
            os.environ["TOGETHER_API_KEY"] = args.api_key
            llm = ChatTogether(
                model=args.model,
                temperature=0,
            )
            try:
                response = llm.invoke("125*125 is?").content
                print(f"{response=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)
        else:  # Llama model on the Llama API
            os.environ["LLAMA_API_KEY"] = args.api_key
            try:
                client = LlamaAPIClient()
                response = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": "125*125 is?"}],
                    temperature=0,
                )
                answer = response.completion_message.content.text
                print(f"{answer=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)

    eval_data = json.load(open(args.eval_path, "r"))
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)
    print("successfully collected results from {}".format(args.model))