text2sql_eval.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import argparse
  2. import json
  3. import multiprocessing as mp
  4. import sqlite3
  5. import sys
  6. from func_timeout import func_timeout, FunctionTimedOut
  7. from tqdm import tqdm
  8. def load_json(dir):
  9. with open(dir, "r") as j:
  10. contents = json.loads(j.read())
  11. return contents
  12. def result_callback(result):
  13. exec_result.append(result)
  14. def execute_sql(predicted_sql, ground_truth, db_path, debug=False):
  15. conn = sqlite3.connect(db_path)
  16. # Connect to the database
  17. cursor = conn.cursor()
  18. cursor.execute(predicted_sql)
  19. predicted_res = cursor.fetchall()
  20. cursor.execute(ground_truth)
  21. ground_truth_res = cursor.fetchall()
  22. res = 0
  23. if set(predicted_res) == set(ground_truth_res):
  24. res = 1
  25. elif debug:
  26. print(
  27. f"\n\n==== INCORRECT SQL GENERATED ====\n{predicted_sql=}\n{predicted_res=}\n{ground_truth=}\n{ground_truth_res=}\n======\n\n"
  28. )
  29. return res
  30. def execute_model(
  31. predicted_sql, ground_truth, db_place, idx, meta_time_out, debug=False
  32. ):
  33. try:
  34. res = func_timeout(
  35. meta_time_out,
  36. execute_sql,
  37. args=(predicted_sql, ground_truth, db_place, debug),
  38. )
  39. except KeyboardInterrupt:
  40. sys.exit(0)
  41. except FunctionTimedOut:
  42. result = [(f"timeout",)]
  43. res = 0
  44. except Exception as e:
  45. result = [(f"{e}",)] # possibly len(query) > 512 or not executable
  46. res = 0
  47. result = {"sql_idx": idx, "res": res}
  48. return result
  49. def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
  50. clean_sqls = []
  51. db_path_list = []
  52. if mode == "gpt":
  53. sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r"))
  54. for idx, sql_str in sql_data.items():
  55. if type(sql_str) == str:
  56. sql, db_name = sql_str.split("\t----- bird -----\t")
  57. else:
  58. sql, db_name = " ", "financial"
  59. clean_sqls.append(sql)
  60. db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
  61. elif mode == "gt": # ground truth
  62. items = json.load(open(db_root_path + "/../dev.json"))
  63. for item in items:
  64. sql = item["SQL"]
  65. db_name = item["db_id"]
  66. clean_sqls.append(sql)
  67. db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
  68. return clean_sqls, db_path_list
  69. def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0, debug=False):
  70. pool = mp.Pool(processes=num_cpus)
  71. # Create a progress bar if not in debug mode
  72. if not debug:
  73. pbar = tqdm(total=len(sqls), desc="Evaluating SQL queries")
  74. for i, sql_pair in enumerate(sqls):
  75. predicted_sql, ground_truth = sql_pair
  76. pool.apply_async(
  77. execute_model,
  78. args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out, debug),
  79. callback=lambda result: result_callback_with_progress(
  80. result, not debug, pbar
  81. ),
  82. )
  83. pool.close()
  84. pool.join()
  85. # Close the progress bar if not in debug mode
  86. if not debug:
  87. pbar.close()
  88. def result_callback_with_progress(result, use_progress, pbar=None):
  89. exec_result.append(result)
  90. if use_progress and pbar:
  91. pbar.update(1)
  92. def sort_results(list_of_dicts):
  93. return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
  94. def compute_acc_by_diff(exec_results, diff_json_path):
  95. num_queries = len(exec_results)
  96. results = [res["res"] for res in exec_results]
  97. contents = load_json(diff_json_path)
  98. simple_results, moderate_results, challenging_results = [], [], []
  99. for i, content in enumerate(contents):
  100. if content["difficulty"] == "simple":
  101. simple_results.append(exec_results[i])
  102. if content["difficulty"] == "moderate":
  103. moderate_results.append(exec_results[i])
  104. if content["difficulty"] == "challenging":
  105. challenging_results.append(exec_results[i])
  106. simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
  107. moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
  108. challenging_acc = (
  109. 0
  110. if len(challenging_results) == 0
  111. else sum([res["res"] for res in challenging_results]) / len(challenging_results)
  112. )
  113. all_acc = sum(results) / num_queries
  114. count_lists = [
  115. len(simple_results),
  116. len(moderate_results),
  117. len(challenging_results),
  118. num_queries,
  119. ]
  120. return (
  121. simple_acc * 100,
  122. moderate_acc * 100,
  123. challenging_acc * 100,
  124. all_acc * 100,
  125. count_lists,
  126. )
  127. def print_data(score_lists, count_lists, debug=False):
  128. levels = ["simple", "moderate", "challenging", "total"]
  129. if debug:
  130. print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
  131. print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
  132. print(
  133. "====================================== ACCURACY ====================================="
  134. )
  135. else:
  136. print("\nEvaluation Results:")
  137. print("-" * 40)
  138. print(
  139. "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
  140. )
  141. if __name__ == "__main__":
  142. args_parser = argparse.ArgumentParser()
  143. args_parser.add_argument(
  144. "--predicted_sql_path", type=str, required=True, default=""
  145. )
  146. args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
  147. args_parser.add_argument("--data_mode", type=str, default="dev")
  148. args_parser.add_argument("--db_root_path", type=str, required=True, default="")
  149. args_parser.add_argument("--num_cpus", type=int, default=1)
  150. args_parser.add_argument("--meta_time_out", type=float, default=30.0)
  151. args_parser.add_argument("--mode_gt", type=str, default="gt")
  152. args_parser.add_argument("--mode_predict", type=str, default="gpt")
  153. args_parser.add_argument("--difficulty", type=str, default="simple")
  154. args_parser.add_argument("--diff_json_path", type=str, default="")
  155. args_parser.add_argument(
  156. "--debug", action="store_true", help="Enable debug mode with detailed prints"
  157. )
  158. args = args_parser.parse_args()
  159. exec_result = []
  160. if args.debug:
  161. print("Debug mode enabled - showing detailed output")
  162. # Show loading progress if not in debug mode
  163. if not args.debug:
  164. print("Loading SQL queries and database paths...")
  165. pred_queries, db_paths = package_sqls(
  166. args.predicted_sql_path,
  167. args.db_root_path,
  168. mode=args.mode_predict,
  169. data_mode=args.data_mode,
  170. )
  171. # generate gt sqls:
  172. gt_queries, db_paths_gt = package_sqls(
  173. args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
  174. )
  175. query_pairs = list(zip(pred_queries, gt_queries))
  176. if args.debug:
  177. print(f"Executing {len(query_pairs)} SQL query pairs...")
  178. run_sqls_parallel(
  179. query_pairs,
  180. db_places=db_paths,
  181. num_cpus=args.num_cpus,
  182. meta_time_out=args.meta_time_out,
  183. debug=args.debug,
  184. )
  185. exec_result = sort_results(exec_result)
  186. if args.debug:
  187. print("Evaluating statistics...")
  188. simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
  189. exec_result, args.diff_json_path
  190. )
  191. score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
  192. print_data(score_lists, count_lists, debug=args.debug)
  193. if args.debug:
  194. print(
  195. "==========================================================================================="
  196. )