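"""create_sft_dataset.py

Build a chat-style SFT dataset for text-to-SQL. For each item in the input
JSON (assumed here to follow the BIRD layout: db_id, question, evidence,
SQL), dump the CREATE TABLE DDLs of the referenced SQLite database, combine
them with the external knowledge and the question into a prompt, pair it
with the gold SQL, and write 70/30 train/test splits as JSON.
"""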
import argparse
import json
import os
import sqlite3

from datasets import Dataset
from tqdm import tqdm


def new_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)


def nice_look_table(column_names: list, values: list):
    rows = []
    # Determine the maximum width of each column
    widths = [
        max(len(str(value[i])) for value in values + [column_names])
        for i in range(len(column_names))
    ]
    # Format the column names
    header = "".join(
        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
    )
    # print(header)
    # Format the values
    for value in values:
        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
        rows.append(row)
    rows = "\n".join(rows)
    final_output = header + "\n" + rows
    return final_output


def generate_schema_prompt(db_path, num_rows=None):
    """Extract the CREATE TABLE DDLs from the SQLite database.

    :param db_path: path to the .sqlite file
    :param num_rows: if set, append that many example rows per table
    :return: the schema portion of the prompt
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        # fetchall() returns 1-tuples, so compare the table name itself
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()
    for v in schemas.values():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    question_prompt = "-- Question: {}".format(question)
    if knowledge is None:
        # Without evidence, skip the line rather than emit "External Knowledge: None"
        return question_prompt
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts


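# For reference, the combined prompt has the shape sketched below (illustrative
# only; the actual DDL text comes from the database file):
#
#   -- DB Schema: CREATE TABLE t1 (...)
#
#   CREATE TABLE t2 (...)
#
#   -- External Knowledge: <evidence from the input JSON>
#
#   -- Question: <natural-language question>

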
def create_conversation(sample):
    # Rebuild the fixed system/user/assistant triple so that dataset.map()
    # can drop the original columns and keep only "messages".
    return {
        "messages": [
            {"role": "system", "content": sample["messages"][0]["content"]},
            {"role": "user", "content": sample["messages"][1]["content"]},
            {"role": "assistant", "content": sample["messages"][2]["content"]},
        ]
    }


def create_sft_dataset(input_json, db_root_path):
    ds = []
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    for item in tqdm(input_json):
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        SQL = item["SQL"]
        db_path = os.path.join(db_root_path, db_id, db_id + ".sqlite")
        prompt = generate_combined_prompts_one(
            db_path,
            question,
            knowledge=external_knowledge,
        )
        example = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": SQL},
            ]
        }
        ds.append(example)
    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
    dataset = Dataset.from_dict(dataset_dict)
    # dataset.save_to_disk("text2sql_sft_dataset")
    dataset = dataset.map(
        create_conversation, remove_columns=dataset.features, batched=False
    )
    dataset = dataset.train_test_split(test_size=0.3)
    dataset["train"].to_json("train_text2sql_sft_dataset.json", orient="records")
    dataset["test"].to_json("test_text2sql_sft_dataset.json", orient="records")


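# Each line of the emitted JSON files is one record of the form
# {"messages": [{"role": "system", "content": "..."},
#               {"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."}]}

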
if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument("--input_json", type=str, required=True)
    args_parser.add_argument("--db_root_path", type=str, required=True)
    args = args_parser.parse_args()
    with open(args.input_json, "r") as f:
        input_json = json.load(f)
    create_sft_dataset(input_json, args.db_root_path)

# python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases