# create_bird_eval_dataset.py — build a BIRD dev-set evaluation CSV of text-to-SQL prompts.
import argparse
import json
import os
import sqlite3

import pandas as pd
# from datasets import Dataset
from tqdm import tqdm
  8. def new_directory(path):
  9. if not os.path.exists(path):
  10. os.makedirs(path)
  11. def nice_look_table(column_names: list, values: list):
  12. rows = []
  13. # Determine the maximum width of each column
  14. widths = [
  15. max(len(str(value[i])) for value in values + [column_names])
  16. for i in range(len(column_names))
  17. ]
  18. # Print the column names
  19. header = "".join(
  20. f"{column.rjust(width)} " for column, width in zip(column_names, widths)
  21. )
  22. # print(header)
  23. # Print the values
  24. for value in values:
  25. row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
  26. rows.append(row)
  27. rows = "\n".join(rows)
  28. final_output = header + "\n" + rows
  29. return final_output
  30. def generate_schema_prompt(db_path, num_rows=None):
  31. # extract create ddls
  32. """
  33. :param root_place:
  34. :param db_name:
  35. :return:
  36. """
  37. full_schema_prompt_list = []
  38. conn = sqlite3.connect(db_path)
  39. # Create a cursor object
  40. cursor = conn.cursor()
  41. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  42. tables = cursor.fetchall()
  43. schemas = {}
  44. for table in tables:
  45. if table == "sqlite_sequence":
  46. continue
  47. cursor.execute(
  48. "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
  49. table[0]
  50. )
  51. )
  52. create_prompt = cursor.fetchone()[0]
  53. schemas[table[0]] = create_prompt
  54. if num_rows:
  55. cur_table = table[0]
  56. if cur_table in ["order", "by", "group"]:
  57. cur_table = "`{}`".format(cur_table)
  58. cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
  59. column_names = [description[0] for description in cursor.description]
  60. values = cursor.fetchall()
  61. rows_prompt = nice_look_table(column_names=column_names, values=values)
  62. verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
  63. num_rows, cur_table, num_rows, rows_prompt
  64. )
  65. schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
  66. for k, v in schemas.items():
  67. full_schema_prompt_list.append(v)
  68. schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
  69. return schema_prompt
  70. def generate_comment_prompt(question, knowledge=None):
  71. knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
  72. question_prompt = "-- Question: {}".format(question)
  73. result_prompt = knowledge_prompt + "\n\n" + question_prompt
  74. return result_prompt
  75. def generate_combined_prompts_one(db_path, question, knowledge=None):
  76. schema_prompt = generate_schema_prompt(db_path, num_rows=None)
  77. comment_prompt = generate_comment_prompt(question, knowledge)
  78. combined_prompts = schema_prompt + "\n\n" + comment_prompt
  79. return combined_prompts
  80. def create_conversation(sample):
  81. return {
  82. "messages": [
  83. {"role": "system", "content": sample["messages"][0]["content"]},
  84. {"role": "user", "content": sample["messages"][1]["content"]},
  85. {"role": "assistant", "content": sample["messages"][2]["content"]},
  86. ]
  87. }
  88. def create_bird_eval_dataset(input_json, db_root_path):
  89. SYSTEM_PROMPT = (
  90. "You are a text to SQL query translator. Using the SQLite DB Schema and the "
  91. "External Knowledge, translate the following text question into a SQLite SQL "
  92. "select statement."
  93. )
  94. data = []
  95. for i, item in tqdm(enumerate(input_json)):
  96. print(f"processing #{i+1}")
  97. db_id = item["db_id"]
  98. question = item["question"]
  99. external_knowledge = item["evidence"]
  100. SQL = item["SQL"]
  101. db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
  102. print(f"{db_path=}")
  103. prompt = generate_combined_prompts_one(
  104. db_path,
  105. question,
  106. knowledge=external_knowledge,
  107. )
  108. data.append(
  109. {
  110. "prompt": SYSTEM_PROMPT + "\n\n" + prompt,
  111. "gold_sql": SQL,
  112. "db_id": db_id,
  113. }
  114. )
  115. df = pd.DataFrame(data)
  116. df.to_csv("bird_dev_set_eval.csv", index=False)
  117. print(f"Dataset saved as bird_dev_set_eval.csv with {len(df)} rows")
  118. if __name__ == "__main__":
  119. args_parser = argparse.ArgumentParser()
  120. args_parser.add_argument("--input_json", type=str, required=True)
  121. args_parser.add_argument("--db_root_path", type=str, required=True)
  122. args = args_parser.parse_args()
  123. input_json = json.load(open(args.input_json, "r"))
  124. db_root_path = args.db_root_path
  125. create_bird_eval_dataset(input_json, db_root_path)
  126. # python3 create_bird_eval_dataset.py --input_json ../data/dev_20240627/dev.json --db_root_path ../data/dev_20240627/dev_databases