|
|
@@ -1,11 +1,10 @@
|
|
|
import argparse
|
|
|
import json
|
|
|
import multiprocessing as mp
|
|
|
-import re
|
|
|
import sqlite3
|
|
|
import sys
|
|
|
|
|
|
-from func_timeout import func_timeout, FunctionTimedOut
|
|
|
+from func_timeout import FunctionTimedOut, func_timeout
|
|
|
|
|
|
|
|
|
def load_json(dir):
|
|
|
@@ -45,10 +44,10 @@ def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
|
|
|
except KeyboardInterrupt:
|
|
|
sys.exit(0)
|
|
|
except FunctionTimedOut:
|
|
|
- result = [(f"timeout",)]
|
|
|
+ result = [("timeout",)]
|
|
|
res = 0
|
|
|
except Exception as e:
|
|
|
- result = [(f"error",)] # possibly len(query) > 512 or not executable
|
|
|
+ result = [("error",)] # possibly len(query) > 512 or not executable
|
|
|
res = 0
|
|
|
result = {"sql_idx": idx, "res": res}
|
|
|
return result
|
|
|
@@ -60,7 +59,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
|
|
|
if mode == "gpt":
|
|
|
sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r"))
|
|
|
for idx, sql_str in sql_data.items():
|
|
|
- if type(sql_str) == str:
|
|
|
+ if isinstance(sql_str, str):
|
|
|
sql, db_name = sql_str.split("\t----- bird -----\t")
|
|
|
else:
|
|
|
sql, db_name = " ", "financial"
|
|
|
@@ -83,7 +82,6 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
|
|
|
def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
|
|
|
pool = mp.Pool(processes=num_cpus)
|
|
|
for i, sql_pair in enumerate(sqls):
|
|
|
-
|
|
|
predicted_sql, ground_truth = sql_pair
|
|
|
pool.apply_async(
|
|
|
execute_model,
|