# create_reasoning_dataset.py
  1. import argparse
  2. import json
  3. import os
  4. import re
  5. import sqlite3
  6. from datasets import Dataset, load_from_disk
  7. from langchain_together import ChatTogether
  8. from llama_api_client import LlamaAPIClient
# Fail fast at import time: at least one provider API key must be configured.
if (
    os.environ.get("LLAMA_API_KEY", "") == ""
    and os.environ.get("TOGETHER_API_KEY", "") == ""
):
    print(
        "Please set the environment variable LLAMA_API_KEY or TOGETHER_API_KEY to your API key."
    )
    exit(1)

if os.environ.get("LLAMA_API_KEY", "") != "":  # Llama model on Llama API
    # Smoke-test the Llama API key with a trivial completion request.
    try:
        client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
        response = client.chat.completions.create(
            model="Llama-3.3-70B-Instruct",
            messages=[{"role": "user", "content": "125*125 is?"}],
            temperature=0,
        )
        answer = response.completion_message.content.text
    except Exception as exception:
        # NOTE(review): this branch only warns and keeps running, while the
        # Together branch below exits on failure — confirm that is intentional.
        print(f"Invalid LLAMA_API_KEY {exception=}")

if os.environ.get("TOGETHER_API_KEY", "") != "":  # Llama model on together
    llm = ChatTogether(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        temperature=0,
    )
    # Smoke-test the Together key with a trivial prompt.
    try:
        answer = llm.invoke("125*125 is?").content
    except Exception as exception:
        print(f"Invalid TOGETHER_API_KEY - {exception=}")
        exit(1)
  38. def llama(prompt, model="Llama-3.3-70B-Instruct"):
  39. if os.environ["LLAMA_API_KEY"] != "":
  40. client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
  41. response = client.chat.completions.create(
  42. model=model, messages=[{"role": "user", "content": prompt}], temperature=0
  43. )
  44. return response.completion_message.content.text
  45. else:
  46. llm = ChatTogether(
  47. model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
  48. temperature=0,
  49. )
  50. answer = llm.invoke(prompt).content
  51. return answer
  52. def new_directory(path):
  53. if not os.path.exists(path):
  54. os.makedirs(path)
  55. def nice_look_table(column_names: list, values: list):
  56. rows = []
  57. # Determine the maximum width of each column
  58. widths = [
  59. max(len(str(value[i])) for value in values + [column_names])
  60. for i in range(len(column_names))
  61. ]
  62. # Print the column names
  63. header = "".join(
  64. f"{column.rjust(width)} " for column, width in zip(column_names, widths)
  65. )
  66. # Print the values
  67. for value in values:
  68. row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
  69. rows.append(row)
  70. rows = "\n".join(rows)
  71. final_output = header + "\n" + rows
  72. return final_output
  73. def generate_schema_prompt(db_path, num_rows=None):
  74. # extract create ddls
  75. """
  76. :param root_place:
  77. :param db_name:
  78. :return:
  79. """
  80. full_schema_prompt_list = []
  81. conn = sqlite3.connect(db_path)
  82. # Create a cursor object
  83. cursor = conn.cursor()
  84. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  85. tables = cursor.fetchall()
  86. schemas = {}
  87. for table in tables:
  88. if table == "sqlite_sequence":
  89. continue
  90. cursor.execute(
  91. "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
  92. table[0]
  93. )
  94. )
  95. create_prompt = cursor.fetchone()[0]
  96. schemas[table[0]] = create_prompt
  97. if num_rows:
  98. cur_table = table[0]
  99. if cur_table in ["order", "by", "group"]:
  100. cur_table = "`{}`".format(cur_table)
  101. cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
  102. column_names = [description[0] for description in cursor.description]
  103. values = cursor.fetchall()
  104. rows_prompt = nice_look_table(column_names=column_names, values=values)
  105. verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
  106. num_rows, cur_table, num_rows, rows_prompt
  107. )
  108. schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
  109. for k, v in schemas.items():
  110. full_schema_prompt_list.append(v)
  111. schema_prompt = "\n\n".join(full_schema_prompt_list)
  112. return schema_prompt
  113. def create_conversation(sample):
  114. return {
  115. "messages": [
  116. {"role": "system", "content": sample["messages"][0]["content"]},
  117. {"role": "user", "content": sample["messages"][1]["content"]},
  118. {"role": "assistant", "content": sample["messages"][2]["content"]},
  119. ]
  120. }
def create_cot_dataset(input_json, db_root_path):
    # Build a chain-of-thought fine-tuning dataset: for each text2sql example,
    # ask the LLM to produce step-by-step reasoning ending in the gold SQL,
    # keep only examples whose extracted SQL matches the gold SQL exactly,
    # then save the kept examples as a HF dataset plus train/test JSON files.
    cot_list = []
    diff = 0  # count of examples whose generated SQL != gold SQL (dropped)
    for i, item in enumerate(input_json):
        print(f"processing #{i+1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        gold_SQL = item["SQL"].strip()
        # assumes BIRD-style layout: <db_root_path>/<db_id>/<db_id>.sqlite — TODO confirm
        db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
        # print(f"{db_path=}")
        db_schema = generate_schema_prompt(db_path)
        prompt_to_generate_reasoning = """
You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input.
-- DB Schema: {db_schema}
-- External Knowledge: {external_knowledge}
-- Text Question Input: {question}
-- Gold SQL Output: {gold_SQL}
Your response should be as follows:\n\n
Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is:
```sql ...```\n
"""
        # str.replace is used instead of str.format because the schema /
        # question text may itself contain braces that would break format().
        prompt_to_generate_reasoning = (
            prompt_to_generate_reasoning.replace("{db_schema}", db_schema)
            .replace("{external_knowledge}", external_knowledge)
            .replace("{question}", question)
            .replace("{gold_SQL}", gold_SQL)
        )
        reasoning = llama(prompt_to_generate_reasoning)
        # Pull the final SQL out of the ```sql ... ``` fenced block.
        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
        matches = pattern.findall(reasoning)
        if matches != []:
            gene_SQL = matches[0].replace("\n", "").strip()
            gene_SQL = re.sub(r"\s{2,}", " ", gene_SQL)  # collapse repeated spaces
        else:
            # No fenced block found: fall back to the raw reasoning text
            # (will almost certainly fail the exact-match filter below).
            gene_SQL = reasoning
        print(f"{diff=}\n{gold_SQL=}\n{gene_SQL=}")
        # Exact string match is the quality filter: only keep examples where
        # the model's reasoning reproduced the gold SQL verbatim.
        if gold_SQL != gene_SQL:
            diff += 1
            continue
        # use the reasoning generated above to generate an example for the reasoning dataset used for fine-tuning
        prompt = f"""
-- DB Schema: {db_schema}
-- External Knowledge: {external_knowledge}
-- Text Question: {question}
"""
        cot = {
            "messages": [
                {
                    "role": "system",
                    "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.",
                },
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": reasoning},
            ]
        }
        cot_list.append(cot)
    print(f"{diff=}, total: {len(input_json)}")
    # Columnar dict for Dataset.from_dict.
    # NOTE(review): cot_list[0] raises IndexError when no example survived the
    # exact-match filter — confirm whether that case needs handling.
    dataset_dict = {key: [d[key] for d in cot_list] for key in cot_list[0]}
    hf_dataset = Dataset.from_dict(dataset_dict)
    hf_dataset.save_to_disk("text2sql_cot_dataset")
    # Round-trip through disk, then normalize each row via create_conversation.
    dataset = load_from_disk("text2sql_cot_dataset")
    dataset = dataset.map(
        create_conversation, remove_columns=dataset.features, batched=False
    )
    # 70/30 train/test split, written as JSON records for fine-tuning.
    dataset = dataset.train_test_split(test_size=0.3)
    dataset["train"].to_json("train_text2sql_cot_dataset.json", orient="records")
    dataset["test"].to_json("test_text2sql_cot_dataset.json", orient="records")
  189. if __name__ == "__main__":
  190. args_parser = argparse.ArgumentParser()
  191. args_parser.add_argument("--input_json", type=str, required=True)
  192. args_parser.add_argument("--db_root_path", type=str, required=True)
  193. args = args_parser.parse_args()
  194. input_json = json.load(open(args.input_json, "r"))
  195. db_root_path = args.db_root_path
  196. create_cot_dataset(input_json, db_root_path)
  197. # python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases