
some refactoring and cleaning

Amir Youssefi 3 months ago
parent
commit
e10ddda64a

+ 3 - 11
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -1,18 +1,11 @@
 import argparse
-import fnmatch
 import json
 import os
-import pdb
-import pickle
 import re
 import sqlite3
-from typing import Dict, List, Tuple
-
-import pandas as pd
-import sqlparse
+from typing import Dict
 
 import torch
-from datasets import Dataset, load_dataset
 from langchain_together import ChatTogether
 from llama_api_client import LlamaAPIClient
 from peft import AutoPeftModelForCausalLM
@@ -284,7 +277,6 @@ def collect_response_from_llama(
     :param question_list: []
     :return: dict of responses
     """
-    responses_dict = {}
     response_list = []
 
     if api_key in ["huggingface", "finetuned"]:
@@ -318,7 +310,7 @@ def collect_response_from_llama(
                 temperature=0,
                 stop=["--", "\n\n", ";", "#"],
             )
-        if type(plain_result) == str:
+        if isinstance(plain_result, str):
             sql = plain_result
         else:
             sql = "SELECT" + plain_result["choices"][0]["text"]
@@ -387,7 +379,7 @@ if __name__ == "__main__":
     args_parser.add_argument("--data_output_path", type=str)
     args = args_parser.parse_args()
 
-    if not args.api_key in ["huggingface", "finetuned"]:
+    if args.api_key not in ["huggingface", "finetuned"]:
         if args.model.startswith("meta-llama/"):  # Llama model on together
 
             os.environ["TOGETHER_API_KEY"] = args.api_key
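
Note: a minimal sketch (not from the repo) of why the two idioms adopted above are preferred. isinstance() also matches subclasses, which an exact type() comparison misses, and `x not in seq` is the PEP 8 spelling of the membership test.

    class SQLString(str):
        """Hypothetical str subclass, e.g. a tagged SQL literal."""

    plain_result = SQLString("SELECT 1")

    print(type(plain_result) == str)      # False: exact-type check misses the subclass
    print(isinstance(plain_result, str))  # True: matches str and any subclass

    api_key = "together-..."              # illustrative value
    print(api_key not in ["huggingface", "finetuned"])  # True, same result as `not ... in`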

+ 4 - 6
end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py

@@ -1,11 +1,10 @@
 import argparse
 import json
 import multiprocessing as mp
-import re
 import sqlite3
 import sys
 
-from func_timeout import func_timeout, FunctionTimedOut
+from func_timeout import FunctionTimedOut, func_timeout
 
 
 def load_json(dir):
@@ -45,10 +44,10 @@ def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
-        result = [(f"timeout",)]
+        result = [("timeout",)]
         res = 0
     except Exception as e:
-        result = [(f"error",)]  # possibly len(query) > 512 or not executable
+        result = [("error",)]  # possibly len(query) > 512 or not executable
         res = 0
     result = {"sql_idx": idx, "res": res}
     return result
@@ -60,7 +59,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     if mode == "gpt":
         sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r"))
         for idx, sql_str in sql_data.items():
-            if type(sql_str) == str:
+            if isinstance(sql_str, str):
                 sql, db_name = sql_str.split("\t----- bird -----\t")
             else:
                 sql, db_name = " ", "financial"
@@ -83,7 +82,6 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
 def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
     pool = mp.Pool(processes=num_cpus)
     for i, sql_pair in enumerate(sqls):
-
         predicted_sql, ground_truth = sql_pair
         pool.apply_async(
             execute_model,
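
Note: a minimal sketch (assumed helper names, not the repo's execute_model) of the func_timeout pattern used above: run a query under a hard time limit and map a timeout or error to a sentinel result instead of crashing the worker pool.

    import sqlite3

    from func_timeout import FunctionTimedOut, func_timeout

    def run_query(db_path, sql):
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    def run_with_timeout(db_path, sql, meta_time_out=30.0):
        try:
            # func_timeout(seconds, callable, args=...) raises FunctionTimedOut
            # when the call does not finish in time.
            return func_timeout(meta_time_out, run_query, args=(db_path, sql))
        except FunctionTimedOut:
            return [("timeout",)]
        except Exception:
            return [("error",)]  # e.g. malformed SQL or an oversized query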

+ 2 - 9
end-to-end-use-cases/coding/text2sql/fine-tuning/create_reasoning_dataset.py

@@ -1,19 +1,12 @@
 import argparse
 import json
 import os
-import pdb
-import pickle
 import re
 import sqlite3
-from typing import Dict, List, Tuple
 
-import sqlparse
 from datasets import Dataset, load_from_disk
-
 from langchain_together import ChatTogether
 from llama_api_client import LlamaAPIClient
-from tqdm import tqdm
-
 
 if (
     os.environ.get("LLAMA_API_KEY", "") == ""
@@ -159,7 +152,7 @@ def create_cot_dataset(input_json, db_root_path):
         question = item["question"]
         external_knowledge = item["evidence"]
         gold_SQL = item["SQL"].strip()
-        db_path = db_root_path + "/" + item["db_id"] + "/" + item["db_id"] + ".sqlite"
+        db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
         # print(f"{db_path=}")
         db_schema = generate_schema_prompt(db_path)
 
@@ -219,7 +212,7 @@ def create_cot_dataset(input_json, db_root_path):
     print(f"{diff=}, total: {len(input_json)}")
     dataset_dict = {key: [d[key] for d in cot_list] for key in cot_list[0]}
     hf_dataset = Dataset.from_dict(dataset_dict)
-    hf_dataset.save_to_disk(f"text2sql_cot_dataset")
+    hf_dataset.save_to_disk("text2sql_cot_dataset")
 
     dataset = load_from_disk("text2sql_cot_dataset")
     dataset = dataset.map(
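
Note: a minimal sketch (made-up sample rows) of the list-of-dicts to dict-of-lists transpose that feeds Dataset.from_dict above; the save_to_disk fix simply drops an f-prefix from a string literal that has no placeholders.

    from datasets import Dataset

    rows = [
        {"question": "How many users signed up?", "SQL": "SELECT COUNT(*) FROM users;"},
        {"question": "What is the latest order?", "SQL": "SELECT MAX(date) FROM orders;"},
    ]

    # Same comprehension shape as in create_cot_dataset: one list per column.
    columns = {key: [r[key] for r in rows] for key in rows[0]}
    ds = Dataset.from_dict(columns)
    ds.save_to_disk("text2sql_cot_dataset_demo")  # demo path, not the repo's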

+ 1 - 7
end-to-end-use-cases/coding/text2sql/fine-tuning/create_sft_dataset.py

@@ -1,15 +1,9 @@
 import argparse
 import json
 import os
-import pdb
-import pickle
-import re
 import sqlite3
-from typing import Dict, List, Tuple
 
-import sqlparse
 from datasets import Dataset
-
 from tqdm import tqdm
 
 
@@ -124,7 +118,7 @@ def create_sft_dataset(input_json, db_root_path):
         question = item["question"]
         external_knowledge = item["evidence"]
         SQL = item["SQL"]
-        db_path = db_root_path + "/" + item["db_id"] + "/" + item["db_id"] + ".sqlite"
+        db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
         print(f"{db_path=}")
         prompt = generate_combined_prompts_one(
             db_path,
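
Note: a minimal sketch of the loop context the db_path fix in both files assumes: db_id is bound once from item["db_id"] and then reused, avoiding the repeated dict lookups. The sample item is illustrative, and os.path.join is an equivalent, more portable spelling on POSIX.

    import os

    db_root_path = "data/dev_databases"  # assumed layout: <root>/<db_id>/<db_id>.sqlite
    item = {"db_id": "financial", "question": "...", "evidence": "...", "SQL": "SELECT 1;"}

    db_id = item["db_id"]
    db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite"
    # Equivalent with os.path.join on POSIX systems:
    db_path = os.path.join(db_root_path, db_id, f"{db_id}.sqlite")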