|
|
@@ -0,0 +1,228 @@
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import os
|
|
|
+import pdb
|
|
|
+import pickle
|
|
|
+import re
|
|
|
+import sqlite3
|
|
|
+from typing import Dict, List, Tuple
|
|
|
+
|
|
|
+import sqlparse
|
|
|
+from datasets import Dataset
|
|
|
+
|
|
|
+from langchain_together import ChatTogether
|
|
|
+from llama_api_client import LlamaAPIClient
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
+
|
|
|
+if (
|
|
|
+ os.environ.get("LLAMA_API_KEY", "") == ""
|
|
|
+ and os.environ.get("TOGETHER_API_KEY", "") == ""
|
|
|
+):
|
|
|
+ print(
|
|
|
+ "Please set the environment variable LLAMA_API_KEY or TOGETHER_API_KEY to your API key."
|
|
|
+ )
|
|
|
+ exit(1)
|
|
|
+
|
|
|
+
|
|
|
+if os.environ.get("LLAMA_API_KEY", "") != "": # Llama model on Llama API
|
|
|
+ try:
|
|
|
+ client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
|
|
|
+
|
|
|
+ response = client.chat.completions.create(
|
|
|
+ model="Llama-3.3-70B-Instruct",
|
|
|
+ messages=[{"role": "user", "content": "125*125 is?"}],
|
|
|
+ temperature=0,
|
|
|
+ )
|
|
|
+ answer = response.completion_message.content.text
|
|
|
+ except Exception as exception:
|
|
|
+ print(f"Invalid LLAMA_API_KEY {exception=}")
|
|
|
+
|
|
|
+if os.environ.get("TOGETHER_API_KEY", "") != "": # Llama model on together
|
|
|
+ llm = ChatTogether(
|
|
|
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
|
|
+ temperature=0,
|
|
|
+ )
|
|
|
+ try:
|
|
|
+ answer = llm.invoke("125*125 is?").content
|
|
|
+ except Exception as exception:
|
|
|
+ print(f"Invalid TOGETHER_API_KEY - {exception=}")
|
|
|
+ exit(1)
|
|
|
+
|
|
|
+
|
|
|
+def llama(prompt, model="Llama-3.3-70B-Instruct"):
|
|
|
+
|
|
|
+ if os.environ["LLAMA_API_KEY"] != "":
|
|
|
+ client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"])
|
|
|
+ response = client.chat.completions.create(
|
|
|
+ model=model, messages=[{"role": "user", "content": prompt}], temperature=0
|
|
|
+ )
|
|
|
+ return response.completion_message.content.text
|
|
|
+ else:
|
|
|
+ llm = ChatTogether(
|
|
|
+ model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
|
|
+ temperature=0,
|
|
|
+ )
|
|
|
+ answer = llm.invoke(prompt).content
|
|
|
+ return answer
|
|
|
+
|
|
|
+
|
|
|
+def new_directory(path):
|
|
|
+ if not os.path.exists(path):
|
|
|
+ os.makedirs(path)
|
|
|
+
|
|
|
+
|
|
|
+def nice_look_table(column_names: list, values: list):
|
|
|
+ rows = []
|
|
|
+ # Determine the maximum width of each column
|
|
|
+ widths = [
|
|
|
+ max(len(str(value[i])) for value in values + [column_names])
|
|
|
+ for i in range(len(column_names))
|
|
|
+ ]
|
|
|
+
|
|
|
+ # Print the column names
|
|
|
+ header = "".join(
|
|
|
+ f"{column.rjust(width)} " for column, width in zip(column_names, widths)
|
|
|
+ )
|
|
|
+ # Print the values
|
|
|
+ for value in values:
|
|
|
+ row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
|
|
|
+ rows.append(row)
|
|
|
+ rows = "\n".join(rows)
|
|
|
+ final_output = header + "\n" + rows
|
|
|
+ return final_output
|
|
|
+
|
|
|
+
|
|
|
+def generate_schema_prompt(db_path, num_rows=None):
|
|
|
+ # extract create ddls
|
|
|
+ """
|
|
|
+ :param root_place:
|
|
|
+ :param db_name:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ full_schema_prompt_list = []
|
|
|
+ conn = sqlite3.connect(db_path)
|
|
|
+ # Create a cursor object
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
|
|
+ tables = cursor.fetchall()
|
|
|
+ schemas = {}
|
|
|
+ for table in tables:
|
|
|
+ if table == "sqlite_sequence":
|
|
|
+ continue
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
|
|
|
+ table[0]
|
|
|
+ )
|
|
|
+ )
|
|
|
+ create_prompt = cursor.fetchone()[0]
|
|
|
+ schemas[table[0]] = create_prompt
|
|
|
+ if num_rows:
|
|
|
+ cur_table = table[0]
|
|
|
+ if cur_table in ["order", "by", "group"]:
|
|
|
+ cur_table = "`{}`".format(cur_table)
|
|
|
+
|
|
|
+ cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
|
|
|
+ column_names = [description[0] for description in cursor.description]
|
|
|
+ values = cursor.fetchall()
|
|
|
+ rows_prompt = nice_look_table(column_names=column_names, values=values)
|
|
|
+ verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
|
|
|
+ num_rows, cur_table, num_rows, rows_prompt
|
|
|
+ )
|
|
|
+ schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
|
|
|
+
|
|
|
+ for k, v in schemas.items():
|
|
|
+ full_schema_prompt_list.append(v)
|
|
|
+
|
|
|
+ schema_prompt = "\n\n".join(full_schema_prompt_list)
|
|
|
+
|
|
|
+ return schema_prompt
|
|
|
+
|
|
|
+
|
|
|
+def create_cot_dataset(input_json, db_root_path):
|
|
|
+ cot_list = []
|
|
|
+ diff = 0
|
|
|
+ for i, item in enumerate(input_json):
|
|
|
+ print(f"processing #{i+1}")
|
|
|
+
|
|
|
+ db_id = item["db_id"]
|
|
|
+ question = item["question"]
|
|
|
+ external_knowledge = item["evidence"]
|
|
|
+ gold_SQL = item["SQL"].strip()
|
|
|
+ db_path = db_root_path + "/" + item["db_id"] + "/" + item["db_id"] + ".sqlite"
|
|
|
+ # print(f"{db_path=}")
|
|
|
+ db_schema = generate_schema_prompt(db_path)
|
|
|
+
|
|
|
+ prompt_to_generate_reasoning = """
|
|
|
+ You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input.
|
|
|
+
|
|
|
+ -- DB Schema: {db_schema}
|
|
|
+ -- External Knowledge: {external_knowledge}
|
|
|
+ -- Text Question Input: {question}
|
|
|
+ -- Gold SQL Output: {gold_SQL}
|
|
|
+
|
|
|
+ Your response should be as follows:\n\n
|
|
|
+ Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is:
|
|
|
+ ```sql ...```\n
|
|
|
+
|
|
|
+ """
|
|
|
+
|
|
|
+ prompt_to_generate_reasoning = (
|
|
|
+ prompt_to_generate_reasoning.replace("{db_schema}", db_schema)
|
|
|
+ .replace("{external_knowledge}", external_knowledge)
|
|
|
+ .replace("{question}", question)
|
|
|
+ .replace("{gold_SQL}", gold_SQL)
|
|
|
+ )
|
|
|
+ reasoning = llama(prompt_to_generate_reasoning)
|
|
|
+ # print(f"\n======\n{prompt_to_generate_reasoning=}\n\n")
|
|
|
+ # print(f"\n======\n{reasoning=}\n\n")
|
|
|
+
|
|
|
+ pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
|
|
|
+ matches = pattern.findall(reasoning)
|
|
|
+ if matches != []:
|
|
|
+ gene_SQL = matches[0].replace("\n", "").strip()
|
|
|
+ gene_SQL = re.sub(r"\s{2,}", " ", gene_SQL)
|
|
|
+ else:
|
|
|
+ gene_SQL = reasoning
|
|
|
+
|
|
|
+ print(f"{diff=}\n{gold_SQL=}\n{gene_SQL=}")
|
|
|
+ if gold_SQL != gene_SQL:
|
|
|
+ diff += 1
|
|
|
+ continue
|
|
|
+
|
|
|
+ # use the reasoning generated above to generate an example for the reasoning dataset used for fine-tuning
|
|
|
+ prompt = f"""
|
|
|
+ -- DB Schema: {db_schema}
|
|
|
+ -- External Knowledge: {external_knowledge}
|
|
|
+ -- Text Question: {question}
|
|
|
+"""
|
|
|
+ cot = {
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "system",
|
|
|
+ "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.",
|
|
|
+ },
|
|
|
+ {"role": "user", "content": prompt},
|
|
|
+ {"role": "assistant", "content": reasoning},
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ cot_list.append(cot)
|
|
|
+
|
|
|
+ print(f"{diff=}, total: {len(input_json)}")
|
|
|
+ dataset_dict = {key: [d[key] for d in cot_list] for key in cot_list[0]}
|
|
|
+ hf_dataset = Dataset.from_dict(dataset_dict)
|
|
|
+ hf_dataset.save_to_disk(f"text2sql_cot_dataset")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ args_parser = argparse.ArgumentParser()
|
|
|
+ args_parser.add_argument("--input_json", type=str, required=True)
|
|
|
+ args_parser.add_argument("--db_root_path", type=str, required=True)
|
|
|
+ args = args_parser.parse_args()
|
|
|
+
|
|
|
+ input_json = json.load(open(args.input_json, "r"))
|
|
|
+ db_root_path = args.db_root_path
|
|
|
+
|
|
|
+ create_cot_dataset(input_json, db_root_path)
|
|
|
+
|
|
|
+# python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases
|