llama_text2sql.py

import argparse
import json
import os
import re
import sqlite3
from typing import Dict

import torch
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# If the API takes a total max-tokens budget (vs. max *new* tokens), we derive it from this
MAX_NEW_TOKENS = 10240

def local_llama(prompt, pipe):
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    # UNCOMMENT TO USE THE FINE-TUNED MODEL TRAINED ON THE REASONING DATASET
    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."

    messages = [
        {"content": SYSTEM_PROMPT, "role": "system"},
        {"role": "user", "content": prompt},
    ]
    raw_prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    print(f"local_llama: {raw_prompt=}")
    outputs = pipe(
        raw_prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,  # greedy decoding; the sampling knobs below are ignored
        temperature=0.0,
        top_k=50,
        top_p=0.1,
        eos_token_id=pipe.tokenizer.eos_token_id,
        pad_token_id=pipe.tokenizer.pad_token_id,
    )

    # Drop the echoed prompt, keeping only the newly generated text
    answer = outputs[0]["generated_text"][len(raw_prompt):].strip()

    # Prefer SQL wrapped in a ```sql ... ``` fence; fall back to the raw answer
    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
    matches = pattern.findall(answer)
    if matches:
        result = matches[0]
    else:
        result = answer
    print(f"{result=}")
    return result
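
# A quick sketch of the fence extraction above, on a hypothetical model answer:
#
#   answer = "Here is the query:\n```sql\nSELECT name FROM schools;\n```"
#   re.findall(r"```sql\n*(.*?)```", answer, re.DOTALL)
#   # -> ["SELECT name FROM schools;\n"]
#
# When the model emits no ```sql fence, the raw answer is returned unchanged.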

def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)

def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
    """
    Read a SQLite file and return the CREATE commands for each of the tables in the database.
    """
    db_dir = "database" if bench_root == "spider" else "databases"
    with sqlite3.connect(
        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
    ) as conn:
        # conn.text_factory = bytes
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                    table[0]
                )
            )
            schemas[table[0]] = cursor.fetchone()[0]
        return schemas
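
# Example return value for a hypothetical two-table database:
#
#   {
#       "frpm": "CREATE TABLE frpm (CDSCode TEXT, ...)",
#       "schools": "CREATE TABLE schools (CDSCode TEXT, ...)",
#   }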

def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column (header included)
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]

    # Build the header row
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Build one right-justified row per record
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output
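
# Example with hypothetical data:
#
#   nice_look_table(["id", "name"], [(1, "Alice"), (2, "Bob")])
#
# returns a right-justified text table:
#
#   id  name
#    1 Alice
#    2   Bob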

def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDL for each table in the database and,
    optionally, append num_rows example rows per table.

    :param db_path: path to the SQLite database file
    :param num_rows: number of example rows to include per table (None to skip)
    :return: schema prompt string
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":  # skip SQLite's internal bookkeeping table
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            if cur_table in ["order", "by", "group"]:  # quote SQL keywords used as table names
                cur_table = "`{}`".format(cur_table)

            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)

    for k, v in schemas.items():
        full_schema_prompt_list.append(v)

    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt
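
# When num_rows is set, each CREATE TABLE statement is followed by a block like
# (hypothetical table and rows):
#
#   CREATE TABLE frpm (...)
#    /*
#    3 example rows:
#    SELECT * FROM frpm LIMIT 3;
#    <three rows rendered by nice_look_table>
#    */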

def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)

    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt

def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)

    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
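
# The combined prompt handed to the model therefore looks like (values illustrative):
#
#   -- DB Schema: CREATE TABLE frpm (CDSCode TEXT, ...)
#
#   CREATE TABLE schools (CDSCode TEXT, ...)
#
#   -- External Knowledge: eligible free rate = `FRPM Count` / `Enrollment`
#
#   -- Question: What is the highest eligible free rate in Alameda County?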

def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
    # NOTE: max_tokens, temperature and stop are accepted for call-site
    # compatibility but are not used below; temperature is pinned to 0.
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    try:
        if model.startswith("meta-llama/"):
            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
            # Character count serves as a crude upper bound on prompt tokens
            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS
            llm = ChatTogether(
                model=model,
                temperature=0,
                max_tokens=final_max_tokens,
            )
            answer = llm.invoke(final_prompt).content
        else:
            client = LlamaAPIClient()
            messages = [
                {"content": SYSTEM_PROMPT, "role": "system"},
                {"role": "user", "content": prompt},
            ]
            # len(messages) is the message count (2), so this is effectively MAX_NEW_TOKENS
            final_max_tokens = len(messages) + MAX_NEW_TOKENS
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0,
                max_completion_tokens=final_max_tokens,
            )
            answer = response.completion_message.content.text

        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(answer)
        if matches:
            result = matches[0]
        else:
            result = answer
        print(result)
    except Exception as e:
        result = "error:{}".format(e)
        print(f"{result=}")
    return result
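
# Note on the max-token arithmetic above: len() counts characters (Together branch)
# or messages (Llama API branch), not tokens, so the cap is a loose over-estimate.
# E.g. a 2,000-character prompt gives final_max_tokens = 2000 + 10240 = 12240, and
# the two-message case gives 2 + 10240 = 10242.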

def huggingface_finetuned(api_key, model):
    if api_key == "finetuned":
        model_id = model

        # Check whether this is a PEFT model by looking for adapter_config.json
        is_peft_model = os.path.exists(os.path.join(model_id, "adapter_config.json"))

        if is_peft_model:
            # Use AutoPeftModelForCausalLM for PEFT fine-tuned models
            print(f"Loading PEFT model from {model_id}")
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)
        else:
            # Use AutoModelForCausalLM for FFT (Full Fine-Tuning) models
            print(f"Loading FFT model from {model_id}")
            model = AutoModelForCausalLM.from_pretrained(
                model_id, device_map="auto", torch_dtype=torch.float16
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)

            # For FFT models, handle the pad token if it was added during training
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                model.resize_token_embeddings(len(tokenizer))

        tokenizer.padding_side = "right"  # to prevent warnings
    elif api_key == "huggingface":
        model_id = model
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,  # pass None to disable 4-bit quantization
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return pipe
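
# Typical usage (model paths/IDs are illustrative, not prescribed by this script):
#
#   pipe = huggingface_finetuned("finetuned", "./llama31-8b-text2sql-peft")
#   pipe = huggingface_finetuned("huggingface", "meta-llama/Llama-3.1-8B-Instruct")
#
# Either way the result is a transformers text-generation pipeline that
# local_llama() consumes.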

def collect_response_from_llama(
    db_path_list, question_list, api_key, model, knowledge_list=None
):
    """
    :param db_path_list: list of database file paths, one per question
    :param question_list: list of natural language questions
    :return: list of SQL responses
    """
    response_list = []

    if api_key in ["huggingface", "finetuned"]:
        pipe = huggingface_finetuned(api_key=api_key, model=model)

    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
        print(
            "--------------------- processing question #{}---------------------".format(
                i + 1
            )
        )
        print("the question is: {}".format(question))

        if knowledge_list:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
            )
        else:
            cur_prompt = generate_combined_prompts_one(
                db_path=db_path_list[i], question=question
            )

        if api_key in ["huggingface", "finetuned"]:
            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
        else:
            plain_result = cloud_llama(
                api_key=api_key,
                model=model,
                prompt=cur_prompt,
                max_tokens=10240,
                temperature=0,
                stop=["--", "\n\n", ";", "#"],
            )

        if isinstance(plain_result, str):
            sql = plain_result
        else:
            sql = "SELECT" + plain_result["choices"][0]["text"]

        # responses_dict[i] = sql
        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
        sql = (
            sql + "\t----- bird -----\t" + db_id
        )  # to avoid unpredicted \t appearing in codex results
        response_list.append(sql)

    return response_list
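
# Each collected entry is the predicted SQL plus a db_id tag, e.g. (illustrative):
#
#   "SELECT COUNT(*) FROM schools\t----- bird -----\tcalifornia_schools"
#
# Downstream BIRD evaluation splits on the "\t----- bird -----\t" marker to
# recover the (sql, db_id) pair.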

def question_package(data_json, knowledge=False):
    question_list = []
    for data in data_json:
        question_list.append(data["question"])
    return question_list


def knowledge_package(data_json, knowledge=False):
    knowledge_list = []
    for data in data_json:
        knowledge_list.append(data["evidence"])
    return knowledge_list

def decouple_question_schema(datasets, db_root_path):
    question_list = []
    db_path_list = []
    knowledge_list = []
    for data in datasets:
        question_list.append(data["question"])
        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
        db_path_list.append(cur_db_path)
        knowledge_list.append(data["evidence"])
    return question_list, db_path_list, knowledge_list
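
# Each entry of the eval JSON is expected to carry at least these BIRD-style
# fields (values illustrative):
#
#   {
#       "question": "What is the highest eligible free rate?",
#       "db_id": "california_schools",
#       "evidence": "eligible free rate = `FRPM Count` / `Enrollment`"
#   }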

def generate_sql_file(sql_lst, output_path=None):
    result = {}
    for i, sql in enumerate(sql_lst):
        result[i] = sql

    if output_path:
        directory_path = os.path.dirname(output_path)
        new_directory(directory_path)
        with open(output_path, "w") as f:
            json.dump(result, f, indent=4)

    return result
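
# The resulting predict file maps question indices to tagged SQL (illustrative):
#
#   {
#       "0": "SELECT COUNT(*) FROM schools\t----- bird -----\tcalifornia_schools"
#   }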

if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--eval_path", type=str, default="")
    args_parser.add_argument("--mode", type=str, default="dev")
    args_parser.add_argument("--test_path", type=str, default="")
    args_parser.add_argument("--use_knowledge", type=str, default="True")
    args_parser.add_argument("--db_root_path", type=str, default="")
    args_parser.add_argument("--api_key", type=str, required=True)
    args_parser.add_argument("--model", type=str, required=True)
    args_parser.add_argument("--data_output_path", type=str)
    args = args_parser.parse_args()

    # Smoke-test the cloud API credentials before running the full eval
    if args.api_key not in ["huggingface", "finetuned"]:
        if args.model.startswith("meta-llama/"):  # Llama model served on Together
            os.environ["TOGETHER_API_KEY"] = args.api_key
            llm = ChatTogether(
                model=args.model,
                temperature=0,
            )
            try:
                response = llm.invoke("125*125 is?").content
                print(f"{response=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)
        else:  # Llama model on the Llama API
            os.environ["LLAMA_API_KEY"] = args.api_key
            try:
                client = LlamaAPIClient()
                response = client.chat.completions.create(
                    model=args.model,
                    messages=[{"role": "user", "content": "125*125 is?"}],
                    temperature=0,
                )
                answer = response.completion_message.content.text
                print(f"{answer=}")
            except Exception as exception:
                print(f"{exception=}")
                exit(1)

    with open(args.eval_path, "r") as f:
        eval_data = json.load(f)
    # '''for debug'''
    # eval_data = eval_data[:3]
    # '''for debug'''

    question_list, db_path_list, knowledge_list = decouple_question_schema(
        datasets=eval_data, db_root_path=args.db_root_path
    )
    assert len(question_list) == len(db_path_list) == len(knowledge_list)

    if args.use_knowledge == "True":
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=knowledge_list,
        )
    else:
        responses = collect_response_from_llama(
            db_path_list=db_path_list,
            question_list=question_list,
            api_key=args.api_key,
            model=args.model,
            knowledge_list=None,
        )

    output_name = args.data_output_path + "predict_" + args.mode + ".json"
    generate_sql_file(sql_lst=responses, output_path=output_name)

    print("successfully collected results from {}".format(args.model))