123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- import argparse
- import json
- import multiprocessing as mp
- import re
- import sqlite3
- import sys
- from func_timeout import func_timeout, FunctionTimedOut
def load_json(dir):
    """Read the JSON document stored at path ``dir`` and return the parsed value.

    The file is decoded explicitly as UTF-8 instead of relying on the
    platform's default locale encoding, so files containing non-ASCII text
    load the same way on every system.

    Args:
        dir: Path of the JSON file to read. The name shadows the ``dir``
            builtin; it is kept unchanged so keyword callers keep working.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces for the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(dir, "r", encoding="utf-8") as j:
        return json.load(j)
def result_callback(result):
    """Record one finished evaluation result.

    Passed as the ``callback`` to ``Pool.apply_async`` in
    ``run_sqls_parallel``. It appends ``result`` to the module-level
    ``exec_result`` list, which the script creates before any work is
    submitted.
    """
    exec_result.append(result)
def execute_sql(predicted_sql, ground_truth, db_path):
    """Run both queries against a SQLite database and compare their results.

    The two result sets are compared as *sets* of row tuples, so row order
    and duplicate rows do not affect the outcome.

    Args:
        predicted_sql: SQL text produced by the model under evaluation.
        ground_truth: Reference SQL text.
        db_path: Path of the SQLite database file to query.

    Returns:
        1 if both queries yield the same set of rows, otherwise 0 (a
        diagnostic showing both queries and their rows is printed).

    Raises:
        sqlite3.Error: If either query fails to execute.
    """
    # Connect to the database
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # Always release the connection, including when a query raises;
        # previously it was never closed.
        conn.close()

    if set(predicted_res) == set(ground_truth_res):
        return 1

    print(
        f"\n\n==== INCORRECT SQL GENERATED ====\n{predicted_sql=}\n{predicted_res=}\n{ground_truth=}\n{ground_truth_res=}\n======\n\n"
    )
    return 0
def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
    """Evaluate one predicted/ground-truth query pair under a time limit.

    ``execute_sql`` is run through ``func_timeout`` so a slow query cannot
    stall the evaluation. A timeout or any execution error counts as an
    incorrect prediction rather than aborting the run.

    Args:
        predicted_sql: SQL text produced by the model.
        ground_truth: Reference SQL text.
        db_place: Path of the SQLite database both queries run against.
        idx: Position of this pair in the evaluation set; echoed back so the
            results can be re-ordered after parallel execution.
        meta_time_out: Time limit passed to ``func_timeout``.

    Returns:
        ``{"sql_idx": idx, "res": res}`` where ``res`` is 1 for a match and
        0 for a mismatch, timeout, or error.
    """
    try:
        res = func_timeout(
            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
        )
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        # Handled in its own clause, before the generic one.
        # NOTE(review): FunctionTimedOut appears not to derive from
        # Exception in func_timeout, so do not merge these clauses - confirm.
        res = 0
    except Exception:
        # possibly len(query) > 512 or not executable
        res = 0
    return {"sql_idx": idx, "res": res}
def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
    """Load SQL strings and the database file each one should run against.

    Paths are built by plain string concatenation, so ``sql_path`` and
    ``db_root_path`` are expected to end with a path separator.

    Args:
        sql_path: Directory prefix holding ``predict_<data_mode>.json``.
            Only used when ``mode == "gpt"``.
        db_root_path: Directory prefix containing one ``<db>/<db>.sqlite``
            per database. In ``"gt"`` mode the reference file is read from
            ``<db_root_path>/../dev.json``.
        mode: ``"gpt"`` reads predictions, a JSON object whose values look
            like ``"<sql>\\t----- bird -----\\t<db_name>"``. ``"gt"`` reads
            ground truth, a JSON list of objects with ``"SQL"`` and
            ``"db_id"`` keys. Any other value yields two empty lists.
        data_mode: Split name used in the prediction file name.

    Returns:
        ``(clean_sqls, db_path_list)``: two parallel lists holding the SQL
        text and the matching SQLite file path.

    Raises:
        OSError: If the JSON file cannot be opened.
        ValueError: If a prediction string does not split into exactly a
            SQL part and a database name.
    """
    clean_sqls = []
    db_path_list = []

    if mode == "gpt":
        # Context managers close the files; the original left them open.
        with open(
            sql_path + "predict_" + data_mode + ".json", "r", encoding="utf-8"
        ) as f:
            sql_data = json.load(f)
        for sql_str in sql_data.values():
            if isinstance(sql_str, str):
                sql, db_name = sql_str.split("\t----- bird -----\t")
            else:
                # Non-string entry: substitute a blank query on a fixed
                # database. NOTE(review): presumably a placeholder for a
                # failed generation - confirm.
                sql, db_name = " ", "financial"
            clean_sqls.append(sql)
            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
    elif mode == "gt":  # ground truth
        with open(db_root_path + "/../dev.json", encoding="utf-8") as f:
            items = json.load(f)
        for item in items:
            db_name = item["db_id"]
            clean_sqls.append(item["SQL"])
            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")

    return clean_sqls, db_path_list
def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
    """Evaluate every query pair in a pool of worker processes.

    Each pair is submitted to ``execute_model`` together with its database
    path and its position in ``sqls``. Finished results are delivered to
    ``result_callback``; nothing is returned from this function.

    Args:
        sqls: Sequence of ``(predicted_sql, ground_truth)`` pairs.
        db_places: Database paths, indexed in step with ``sqls``.
        num_cpus: Number of worker processes in the pool.
        meta_time_out: Per-pair time limit forwarded to ``execute_model``.
    """
    worker_pool = mp.Pool(processes=num_cpus)
    for position, (predicted_sql, ground_truth) in enumerate(sqls):
        task_args = (
            predicted_sql,
            ground_truth,
            db_places[position],
            position,
            meta_time_out,
        )
        worker_pool.apply_async(
            execute_model, args=task_args, callback=result_callback
        )
    # Stop accepting work, then wait for every submitted task to finish.
    worker_pool.close()
    worker_pool.join()
def sort_results(list_of_dicts):
    """Return a new list of the result dicts ordered by ascending ``sql_idx``.

    The input list is left unmodified.
    """
    ordered = list(list_of_dicts)
    ordered.sort(key=lambda entry: entry["sql_idx"])
    return ordered
def compute_acc_by_diff(exec_results, diff_json_path):
    """Compute execution accuracy overall and per difficulty level.

    The difficulty file is loaded with ``load_json``; its i-th entry supplies
    the ``"difficulty"`` label for ``exec_results[i]``. Entries whose label
    is not ``simple``, ``moderate`` or ``challenging`` count only toward the
    overall figure.

    An empty bucket (or an empty ``exec_results``) yields an accuracy of 0
    rather than raising ``ZeroDivisionError``; previously only the
    challenging bucket was guarded this way.

    Args:
        exec_results: List of ``{"sql_idx": ..., "res": 0 or 1}`` dicts,
            ordered to line up with the difficulty file.
        diff_json_path: Path of the JSON file holding the difficulty labels.

    Returns:
        ``(simple_acc, moderate_acc, challenging_acc, all_acc, count_lists)``
        where the accuracies are percentages and ``count_lists`` is
        ``[n_simple, n_moderate, n_challenging, n_total]``.
    """
    num_queries = len(exec_results)
    results = [res["res"] for res in exec_results]
    contents = load_json(diff_json_path)

    levels = ("simple", "moderate", "challenging")
    scores_by_level = {level: [] for level in levels}
    for i, content in enumerate(contents):
        difficulty = content["difficulty"]
        if difficulty in levels:
            scores_by_level[difficulty].append(exec_results[i]["res"])

    def _accuracy(scores):
        # An empty group scores 0 instead of dividing by zero.
        return sum(scores) / len(scores) if scores else 0

    count_lists = [len(scores_by_level[level]) for level in levels]
    count_lists.append(num_queries)

    return (
        _accuracy(scores_by_level["simple"]) * 100,
        _accuracy(scores_by_level["moderate"]) * 100,
        _accuracy(scores_by_level["challenging"]) * 100,
        _accuracy(results) * 100,
        count_lists,
    )
def print_data(score_lists, count_lists):
    """Print a fixed-width table of per-difficulty counts and accuracies.

    Args:
        score_lists: Four accuracy values in the order simple, moderate,
            challenging, total; printed with two decimals.
        count_lists: Four counts in the same order.
    """
    levels = ["simple", "moderate", "challenging", "total"]

    # Each row is a 20-wide label cell followed by four 20-wide data cells.
    header_row = " ".join(["{:20}"] * 5)
    count_row = " ".join(["{:20}"] + ["{:<20}"] * 4)
    score_row = " ".join(["{:20}"] + ["{:<20.2f}"] * 4)

    print(header_row.format("", *levels))
    print(count_row.format("count", *count_lists))
    print(
        "====================================== ACCURACY ====================================="
    )
    print(score_row.format("accuracy", *score_lists))
if __name__ == "__main__":
    # Command-line entry point: load predicted and ground-truth SQL, execute
    # each pair in parallel, then print accuracy broken down by difficulty.
    args_parser = argparse.ArgumentParser()
    # Required options take no default: argparse never uses one for them.
    args_parser.add_argument("--predicted_sql_path", type=str, required=True)
    args_parser.add_argument("--ground_truth_path", type=str, required=True)
    args_parser.add_argument("--data_mode", type=str, default="dev")
    args_parser.add_argument("--db_root_path", type=str, required=True)
    args_parser.add_argument("--num_cpus", type=int, default=1)
    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
    args_parser.add_argument("--mode_gt", type=str, default="gt")
    args_parser.add_argument("--mode_predict", type=str, default="gpt")
    # NOTE(review): --difficulty is accepted but not read anywhere below.
    args_parser.add_argument("--difficulty", type=str, default="simple")
    args_parser.add_argument("--diff_json_path", type=str, default="")
    args = args_parser.parse_args()

    # Module-level accumulator that result_callback appends to; it must
    # exist before any work is submitted to the pool.
    exec_result = []

    pred_queries, db_paths = package_sqls(
        args.predicted_sql_path,
        args.db_root_path,
        mode=args.mode_predict,
        data_mode=args.data_mode,
    )
    # generate gt sqls (honours --mode_gt, which was previously parsed but
    # ignored; its default "gt" keeps the old behaviour):
    gt_queries, db_paths_gt = package_sqls(
        args.ground_truth_path,
        args.db_root_path,
        mode=args.mode_gt,
        data_mode=args.data_mode,
    )

    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(
        query_pairs,
        db_places=db_paths,
        num_cpus=args.num_cpus,
        meta_time_out=args.meta_time_out,
    )
    # Results are collected in completion order; restore input order so
    # they line up with the difficulty file.
    exec_result = sort_results(exec_result)

    print("Evaluating statistics...")
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
        exec_result, args.diff_json_path
    )
    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
    print(
        "==========================================================================================="
    )
|