create_reasoning_dataset.py

import argparse
import json
import os
import re
import sqlite3

from datasets import Dataset, load_from_disk
from langchain_together import ChatTogether
from llama_api_client import LlamaAPIClient
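
# Smoke-test the configured API key with a trivial completion before doing any work;
# exit if neither key is set or the configured key is invalid.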
if (
    os.environ.get("LLAMA_API_KEY", "") == ""
    and os.environ.get("TOGETHER_API_KEY", "") == ""
):
    print(
        "Please set the environment variable LLAMA_API_KEY or TOGETHER_API_KEY to your API key."
    )
    exit(1)

if os.environ.get("LLAMA_API_KEY", "") != "":  # Llama model on Llama API
    try:
        client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
        response = client.chat.completions.create(
            model="Llama-3.3-70B-Instruct",
            messages=[{"role": "user", "content": "125*125 is?"}],
            temperature=0,
        )
        answer = response.completion_message.content.text
    except Exception as exception:
        print(f"Invalid LLAMA_API_KEY - {exception=}")
        exit(1)

if os.environ.get("TOGETHER_API_KEY", "") != "":  # Llama model on Together
    llm = ChatTogether(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        temperature=0,
    )
    try:
        answer = llm.invoke("125*125 is?").content
    except Exception as exception:
        print(f"Invalid TOGETHER_API_KEY - {exception=}")
        exit(1)
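

# Send a single-turn prompt to a Llama 3.3 70B model, using the Llama API when
# LLAMA_API_KEY is set and falling back to Together otherwise.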
def llama(prompt, model="Llama-3.3-70B-Instruct"):
    if os.environ.get("LLAMA_API_KEY", "") != "":
        client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        return response.completion_message.content.text
    else:
        llm = ChatTogether(
            model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
            temperature=0,
        )
        return llm.invoke(prompt).content
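

# Create a directory if it does not already exist.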
def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)
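

# Render column names and rows as a fixed-width text table, for embedding example
# rows in the schema prompt.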
def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Build the header row from the column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # Build one right-justified row per record
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extract the CREATE TABLE DDL for every table in the SQLite database,
    optionally followed by up to num_rows example rows per table.

    :param db_path: path to the .sqlite file
    :param num_rows: number of example rows to include per table (None to skip)
    :return: the full schema prompt as a single string
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "\n\n".join(full_schema_prompt_list)
    return schema_prompt
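

# Rebuild each stored example as a system/user/assistant message triple.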
def create_conversation(sample):
    return {
        "messages": [
            {"role": "system", "content": sample["messages"][0]["content"]},
            {"role": "user", "content": sample["messages"][1]["content"]},
            {"role": "assistant", "content": sample["messages"][2]["content"]},
        ]
    }
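

# For each (question, gold SQL) pair, ask the model for step-by-step reasoning,
# keep only the examples whose extracted SQL exactly matches the gold SQL, and
# save the result as a Hugging Face dataset with a train/test split.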
def create_cot_dataset(input_json, db_root_path):
    cot_list = []
    diff = 0
    for i, item in enumerate(input_json):
        print(f"processing #{i+1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        gold_SQL = item["SQL"].strip()
        db_path = os.path.join(db_root_path, db_id, f"{db_id}.sqlite")
        # print(f"{db_path=}")
        db_schema = generate_schema_prompt(db_path)
        prompt_to_generate_reasoning = """
You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input.
-- DB Schema: {db_schema}
-- External Knowledge: {external_knowledge}
-- Text Question Input: {question}
-- Gold SQL Output: {gold_SQL}
Your response should be as follows:\n\n
Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is:
```sql ...```\n
"""
        prompt_to_generate_reasoning = (
            prompt_to_generate_reasoning.replace("{db_schema}", db_schema)
            .replace("{external_knowledge}", external_knowledge)
            .replace("{question}", question)
            .replace("{gold_SQL}", gold_SQL)
        )
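        # Generate the reasoning, then pull the final SQL out of the ```sql ...```
        # fence; fall back to comparing the raw response if no fenced SQL is found.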
        reasoning = llama(prompt_to_generate_reasoning)
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(reasoning)
        if matches:
            gene_SQL = matches[0].replace("\n", "").strip()
            gene_SQL = re.sub(r"\s{2,}", " ", gene_SQL)
        else:
            gene_SQL = reasoning
        print(f"{diff=}\n{gold_SQL=}\n{gene_SQL=}")
        if gold_SQL != gene_SQL:
            diff += 1
            continue
        # Use the verified reasoning above as an example in the reasoning dataset used for fine-tuning
        prompt = f"""
-- DB Schema: {db_schema}
-- External Knowledge: {external_knowledge}
-- Text Question: {question}
"""
        cot = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.",
                },
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": reasoning},
            ]
        }
        cot_list.append(cot)
    print(f"{diff=}, total: {len(input_json)}")
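    # Materialize the examples as a Hugging Face Dataset, normalize them into
    # chat conversations, and write a 70/30 train/test split to JSON.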
    dataset_dict = {key: [d[key] for d in cot_list] for key in cot_list[0]}
    hf_dataset = Dataset.from_dict(dataset_dict)
    hf_dataset.save_to_disk("text2sql_cot_dataset")
    dataset = load_from_disk("text2sql_cot_dataset")
    dataset = dataset.map(
        create_conversation, remove_columns=list(dataset.features), batched=False
    )
    dataset = dataset.train_test_split(test_size=0.3)
    dataset["train"].to_json("train_text2sql_cot_dataset.json", orient="records")
    dataset["test"].to_json("test_text2sql_cot_dataset.json", orient="records")


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--input_json", type=str, required=True)
    args_parser.add_argument("--db_root_path", type=str, required=True)
    args = args_parser.parse_args()
    with open(args.input_json, "r") as f:
        input_json = json.load(f)
    db_root_path = args.db_root_path
    create_cot_dataset(input_json, db_root_path)

# python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases