# create_sft_dataset.py
  1. import argparse
  2. import json
  3. import os
  4. import sqlite3
  5. from datasets import Dataset
  6. from tqdm import tqdm
  7. def new_directory(path):
  8. if not os.path.exists(path):
  9. os.makedirs(path)
  10. def nice_look_table(column_names: list, values: list):
  11. rows = []
  12. # Determine the maximum width of each column
  13. widths = [
  14. max(len(str(value[i])) for value in values + [column_names])
  15. for i in range(len(column_names))
  16. ]
  17. # Print the column names
  18. header = "".join(
  19. f"{column.rjust(width)} " for column, width in zip(column_names, widths)
  20. )
  21. # print(header)
  22. # Print the values
  23. for value in values:
  24. row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
  25. rows.append(row)
  26. rows = "\n".join(rows)
  27. final_output = header + "\n" + rows
  28. return final_output
  29. def generate_schema_prompt(db_path, num_rows=None):
  30. # extract create ddls
  31. """
  32. :param root_place:
  33. :param db_name:
  34. :return:
  35. """
  36. full_schema_prompt_list = []
  37. conn = sqlite3.connect(db_path)
  38. # Create a cursor object
  39. cursor = conn.cursor()
  40. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  41. tables = cursor.fetchall()
  42. schemas = {}
  43. for table in tables:
  44. if table == "sqlite_sequence":
  45. continue
  46. cursor.execute(
  47. "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
  48. table[0]
  49. )
  50. )
  51. create_prompt = cursor.fetchone()[0]
  52. schemas[table[0]] = create_prompt
  53. if num_rows:
  54. cur_table = table[0]
  55. if cur_table in ["order", "by", "group"]:
  56. cur_table = "`{}`".format(cur_table)
  57. cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
  58. column_names = [description[0] for description in cursor.description]
  59. values = cursor.fetchall()
  60. rows_prompt = nice_look_table(column_names=column_names, values=values)
  61. verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
  62. num_rows, cur_table, num_rows, rows_prompt
  63. )
  64. schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
  65. for k, v in schemas.items():
  66. full_schema_prompt_list.append(v)
  67. schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
  68. return schema_prompt
  69. def generate_comment_prompt(question, knowledge=None):
  70. knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
  71. question_prompt = "-- Question: {}".format(question)
  72. result_prompt = knowledge_prompt + "\n\n" + question_prompt
  73. return result_prompt
  74. def generate_combined_prompts_one(db_path, question, knowledge=None):
  75. schema_prompt = generate_schema_prompt(db_path, num_rows=None)
  76. comment_prompt = generate_comment_prompt(question, knowledge)
  77. combined_prompts = schema_prompt + "\n\n" + comment_prompt
  78. return combined_prompts
  79. def create_conversation(sample):
  80. return {
  81. "messages": [
  82. {"role": "system", "content": sample["messages"][0]["content"]},
  83. {"role": "user", "content": sample["messages"][1]["content"]},
  84. {"role": "assistant", "content": sample["messages"][2]["content"]},
  85. ]
  86. }
  87. def create_sft_dataset(input_json, db_root_path):
  88. ds = []
  89. SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
  90. for i, item in tqdm(enumerate(input_json)):
  91. print(f"processing #{i+1}")
  92. db_id = item["db_id"]
  93. question = item["question"]
  94. external_knowledge = item["evidence"]
  95. SQL = item["SQL"]
  96. db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
  97. print(f"{db_path=}")
  98. prompt = generate_combined_prompts_one(
  99. db_path,
  100. question,
  101. knowledge=external_knowledge,
  102. )
  103. example = {
  104. "messages": [
  105. {"role": "system", "content": SYSTEM_PROMPT},
  106. {"role": "user", "content": prompt},
  107. {"role": "assistant", "content": SQL},
  108. ]
  109. }
  110. ds.append(example)
  111. dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
  112. dataset = Dataset.from_dict(dataset_dict)
  113. # dataset.save_to_disk(f"text2sql_sft_dataset")
  114. dataset = dataset.map(
  115. create_conversation, remove_columns=dataset.features, batched=False
  116. )
  117. dataset = dataset.train_test_split(test_size=0.3)
  118. dataset["train"].to_json("train_text2sql_sft_dataset.json", orient="records")
  119. dataset["test"].to_json("test_text2sql_sft_dataset.json", orient="records")
  120. if __name__ == "__main__":
  121. args_parser = argparse.ArgumentParser()
  122. args_parser.add_argument("--input_json", type=str, required=True)
  123. args_parser.add_argument("--db_root_path", type=str, required=True)
  124. args = args_parser.parse_args()
  125. input_json = json.load(open(args.input_json, "r"))
  126. db_root_path = args.db_root_path
  127. create_sft_dataset(input_json, db_root_path)
  128. # python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases