
text2sql eval and ft tools

jeffxtang 4 months ago
parent
commit
ebed0efb88

File diff too large to display
+ 127 - 0
end-to-end-use-cases/coding/text2sql/tool/README.md


+ 9 - 0
end-to-end-use-cases/coding/text2sql/tool/data/download_dev_unzip.sh

@@ -0,0 +1,9 @@
+wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip
+unzip dev.zip
+rm dev.zip
+rm -rf __MACOSX
+cd dev_20240627
+unzip dev_databases.zip
+rm dev_databases.zip
+rm -rf __MACOSX
+cd ..
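+# After this script runs, the layout expected by llama_eval.sh should be in place
+# (a sketch of the expected tree, based on the paths used elsewhere in this commit):
+#   dev_20240627/dev.json         <- eval questions (eval_path)
+#   dev_20240627/dev_databases/   <- per-db folders with <db_id>.sqlite (db_root_path)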

+ 9 - 0
end-to-end-use-cases/coding/text2sql/tool/data/download_train_unzip.sh

@@ -0,0 +1,9 @@
+wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip
+UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip
+rm train.zip
+rm -rf __MACOSX
+cd train
+unzip train_databases.zip
+rm train_databases.zip
+rm -rf __MACOSX
+cd ..
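+# The resulting layout feeds create_sft_dataset.py (see the usage comment there):
+#   train/train.json              <- SFT source questions/SQL (--input_json)
+#   train/train_databases/        <- per-db folders with <db_id>.sqlite (--db_root_path)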

+ 169 - 0
end-to-end-use-cases/coding/text2sql/tool/fine_tuning/create_sft_dataset.py

@@ -0,0 +1,169 @@
+import argparse
+import json
+import os
+import sqlite3
+
+from datasets import Dataset
+from tqdm import tqdm
+
+
+def new_directory(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def nice_look_table(column_names: list, values: list):
+    rows = []
+    # Determine the maximum width of each column
+    widths = [
+        max(len(str(value[i])) for value in values + [column_names])
+        for i in range(len(column_names))
+    ]
+
+    # Print the column names
+    header = "".join(
+        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
+    )
+    # print(header)
+    # Print the values
+    for value in values:
+        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
+        rows.append(row)
+    rows = "\n".join(rows)
+    final_output = header + "\n" + rows
+    return final_output
+
+
+def generate_schema_prompt(db_path, num_rows=None):
+    """
+    Extract the CREATE TABLE DDLs from the SQLite database at db_path and
+    build a schema prompt; if num_rows is set, append that many example rows
+    per table.
+    """
+    full_schema_prompt_list = []
+    conn = sqlite3.connect(db_path)
+    # Create a cursor object
+    cursor = conn.cursor()
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = cursor.fetchall()
+    schemas = {}
+    for table in tables:
+        if table[0] == "sqlite_sequence":
+            continue
+        cursor.execute(
+            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
+                table[0]
+            )
+        )
+        create_prompt = cursor.fetchone()[0]
+        schemas[table[0]] = create_prompt
+        if num_rows:
+            cur_table = table[0]
+            if cur_table in ["order", "by", "group"]:  # quote names that clash with SQL keywords
+                cur_table = "`{}`".format(cur_table)
+
+            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
+            column_names = [description[0] for description in cursor.description]
+            values = cursor.fetchall()
+            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
+                num_rows, cur_table, num_rows, rows_prompt
+            )
+            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
+
+    for k, v in schemas.items():
+        full_schema_prompt_list.append(v)
+
+    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
+
+    return schema_prompt
+
+
+def generate_comment_prompt(question, knowledge=None):
+    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
+    question_prompt = "-- Question: {}".format(question)
+
+    result_prompt = knowledge_prompt + "\n\n" + question_prompt
+
+    return result_prompt
+
+
+def generate_combined_prompts_one(db_path, question, knowledge=None):
+    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
+    comment_prompt = generate_comment_prompt(question, knowledge)
+
+    combined_prompts = schema_prompt + "\n\n" + comment_prompt
+
+    return combined_prompts
+
+
+def create_conversation(sample):
+    # Normalize each record to the system/user/assistant triple
+    return {
+        "messages": [
+            {"role": "system", "content": sample["messages"][0]["content"]},
+            {"role": "user", "content": sample["messages"][1]["content"]},
+            {"role": "assistant", "content": sample["messages"][2]["content"]},
+        ]
+    }
+
+
+def create_sft_dataset(input_json, db_root_path):
+    ds = []
+    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
+
+    for i, item in tqdm(enumerate(input_json), total=len(input_json)):
+        print(f"processing #{i+1}")
+        db_id = item["db_id"]
+        question = item["question"]
+        external_knowledge = item["evidence"]
+        SQL = item["SQL"]
+        db_path = db_root_path + "/" + item["db_id"] + "/" + item["db_id"] + ".sqlite"
+        print(f"{db_path=}")
+        prompt = generate_combined_prompts_one(
+            db_path,
+            question,
+            knowledge=external_knowledge,
+        )
+
+        example = {
+            "messages": [
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": SQL},
+            ]
+        }
+
+        ds.append(example)
+
+    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
+    dataset = Dataset.from_dict(dataset_dict)
+    # dataset.save_to_disk(f"text2sql_sft_dataset")
+
+    dataset = dataset.map(
+        create_conversation, remove_columns=dataset.features, batched=False
+    )
+    dataset = dataset.train_test_split(test_size=0.3)
+
+    dataset["train"].to_json("train_text2sql_sft_dataset.json", orient="records")
+    dataset["test"].to_json("test_text2sql_sft_dataset.json", orient="records")
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument("--input_json", type=str, required=True)
+    args_parser.add_argument("--db_root_path", type=str, required=True)
+    args = args_parser.parse_args()
+
+    input_json = json.load(open(args.input_json, "r"))
+    db_root_path = args.db_root_path
+
+    create_sft_dataset(input_json, db_root_path)
+
+# python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases
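+
+# A sketch of one record written to train_text2sql_sft_dataset.json (contents
+# abbreviated, not taken from a real run):
+#
+# {"messages": [
+#     {"role": "system", "content": "You are a text to SQL query translator. ..."},
+#     {"role": "user", "content": "-- DB Schema: CREATE TABLE ...\n\n-- External Knowledge: ...\n\n-- Question: ..."},
+#     {"role": "assistant", "content": "SELECT ..."}
+# ]}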

+ 83 - 0
end-to-end-use-cases/coding/text2sql/tool/fine_tuning/trl_sft.py

@@ -0,0 +1,83 @@
+# Source: https://www.philschmid.de/fine-tune-llms-in-2024-with-trl
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    TrainingArguments,
+)
+from trl import SFTTrainer
+
+dataset = load_dataset(
+    "json", data_files="train_text2sql_sft_dataset.json", split="train"
+)
+
+model_id = "meta-llama/Llama-3.1-8B-Instruct"
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    quantization_config=bnb_config,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+tokenizer.padding_side = "right"
+
+if tokenizer.pad_token is None:
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+    model.resize_token_embeddings(len(tokenizer))
+
+peft_config = LoraConfig(
+    lora_alpha=128,
+    lora_dropout=0.05,
+    r=256,
+    bias="none",
+    target_modules="all-linear",
+    task_type="CAUSAL_LM",
+)
+
+args = TrainingArguments(
+    output_dir="llama31-8b-text2sql-epochs-3",  # directory to save and repository id
+    num_train_epochs=3,  # number of training epochs
+    per_device_train_batch_size=3,  # batch size per device during training
+    gradient_accumulation_steps=2,  # number of steps before performing a backward/update pass
+    gradient_checkpointing=True,  # use gradient checkpointing to save memory
+    optim="adamw_torch_fused",  # use fused adamw optimizer
+    logging_steps=10,  # log every 10 steps
+    save_strategy="epoch",  # save checkpoint every epoch
+    learning_rate=2e-4,  # learning rate, based on QLoRA paper
+    bf16=True,  # use bfloat16 precision
+    tf32=True,  # use tf32 precision
+    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
+    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
+    lr_scheduler_type="constant",  # use constant learning rate scheduler
+    push_to_hub=True,  # push model to hub
+    report_to="tensorboard",  # report metrics to tensorboard
+)
+
+max_seq_length = 4096
+
+trainer = SFTTrainer(
+    model=model,
+    args=args,
+    train_dataset=dataset,
+    max_seq_length=max_seq_length,
+    tokenizer=tokenizer,
+    peft_config=peft_config,
+    packing=True,
+)
+
+trainer.train()
+
+trainer.save_model()
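+
+# After training, the LoRA adapter saved in output_dir can be evaluated end to end
+# by pointing llama_eval.sh at it (YOUR_API_KEY='finetuned'); llama_text2sql.py then
+# loads it roughly like the sketch below (mirroring its huggingface_finetuned()):
+#
+#   from peft import AutoPeftModelForCausalLM
+#   model = AutoPeftModelForCausalLM.from_pretrained(
+#       "llama31-8b-text2sql-epochs-3", device_map="auto", torch_dtype=torch.float16
+#   )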

+ 54 - 0
end-to-end-use-cases/coding/text2sql/tool/llama_eval.sh

@@ -0,0 +1,54 @@
+eval_path='./data/dev_20240627/dev.json'
+db_root_path='./data/dev_20240627/dev_databases/'
+ground_truth_path='./data/'
+
+# Set YOUR_API_KEY to your Together or Llama API key when using a cloud-hosted model:
+#YOUR_API_KEY='xxx'
+#YOUR_API_KEY='yyy'
+
+# Llama model on Hugging Face Hub
+# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
+YOUR_API_KEY='huggingface'
+model='meta-llama/Llama-3.1-8B-Instruct'
+
+# Fine-tuned Llama model locally
+# YOUR_API_KEY='finetuned'
+# model='fine_tuning/llama31-8b-text2sql-epochs-3'
+
+# Llama models on Together
+#model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
+#model='meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo'
+#model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
+#model='meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
+#model='meta-llama/Llama-4-Scout-17B-16E-Instruct'
+
+# Llama models on Llama API
+#model='Llama-3.3-8B-Instruct'
+#model='Llama-3.3-70B-Instruct'
+#model='Llama-4-Maverick-17B-128E-Instruct-FP8'
+#model='Llama-4-Scout-17B-16E-Instruct-FP8'
+
+#model="llama31-8b-text-sql-epochs-25"
+#model="llama31-8b-text-sql-epochs-3"
+#model="llama31-8b-text-sql"
+
+#data_output_path="./output/$model/run_500/no_ft/v3/"
+#data_output_path="./output/$model/run_500/ft_epochs-25/v3/"
+data_output_path="./output/$model/"
+
+echo "Text2SQL using $model"
+python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \
+--model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}
+
+# Check if llama_text2sql.py exited successfully
+exit_code=$?
+if [ $exit_code -eq 0 ]; then
+    echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."
+    python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
+    --ground_truth_path ${ground_truth_path} \
+    --diff_json_path ${eval_path}
+
+    echo "Done evaluating $model."
+
+else
+    echo "Error: llama_text2sql.py failed with exit code $exit_code. Skipping evaluation."
+    exit 1
+fi

+ 416 - 0
end-to-end-use-cases/coding/text2sql/tool/llama_text2sql.py

@@ -0,0 +1,416 @@
+import argparse
+import json
+import os
+import re
+import sqlite3
+from typing import Dict
+
+import torch
+from langchain_together import ChatTogether  # for Llama models hosted on Together
+from llama_api_client import LlamaAPIClient  # for Llama models on Llama API
+from peft import AutoPeftModelForCausalLM
+from tqdm import tqdm
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    pipeline,
+)
+
+
+def local_llama(prompt, pipe):
+    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
+
+    messages = [
+        {"content": SYSTEM_PROMPT, "role": "system"},
+        {"role": "user", "content": prompt},
+    ]
+
+    raw_prompt = pipe.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    print(f"local_llama: {raw_prompt=}")
+
+    outputs = pipe(
+        raw_prompt,
+        max_new_tokens=10240,
+        do_sample=False,  # greedy decoding; sampling params would be ignored
+        eos_token_id=pipe.tokenizer.eos_token_id,
+        pad_token_id=pipe.tokenizer.pad_token_id,
+    )
+
+    generated_answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
+
+    print(f"{generated_answer=}")
+    return generated_answer
+
+
+def new_directory(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
+    """
+    Read an sqlite file, and return the CREATE commands for each of the tables in the database.
+    """
+    db_dir = "database" if bench_root == "spider" else "databases"
+    with sqlite3.connect(
+        f"file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
+    ) as conn:
+        # conn.text_factory = bytes
+        cursor = conn.cursor()
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+        tables = cursor.fetchall()
+        schemas = {}
+        for table in tables:
+            cursor.execute(
+                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
+                    table[0]
+                )
+            )
+            schemas[table[0]] = cursor.fetchone()[0]
+
+        return schemas
+
+
+def nice_look_table(column_names: list, values: list):
+    rows = []
+    # Determine the maximum width of each column
+    widths = [
+        max(len(str(value[i])) for value in values + [column_names])
+        for i in range(len(column_names))
+    ]
+
+    # Print the column names
+    header = "".join(
+        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
+    )
+    # print(header)
+    # Print the values
+    for value in values:
+        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
+        rows.append(row)
+    rows = "\n".join(rows)
+    final_output = header + "\n" + rows
+    return final_output
+
+
+def generate_schema_prompt(db_path, num_rows=None):
+    """
+    Extract the CREATE TABLE DDLs from the SQLite database at db_path and
+    build a schema prompt; if num_rows is set, append that many example rows
+    per table.
+    """
+    full_schema_prompt_list = []
+    conn = sqlite3.connect(db_path)
+    # Create a cursor object
+    cursor = conn.cursor()
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = cursor.fetchall()
+    schemas = {}
+    for table in tables:
+        if table[0] == "sqlite_sequence":
+            continue
+        cursor.execute(
+            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
+                table[0]
+            )
+        )
+        create_prompt = cursor.fetchone()[0]
+        schemas[table[0]] = create_prompt
+        if num_rows:
+            cur_table = table[0]
+            if cur_table in ["order", "by", "group"]:  # quote names that clash with SQL keywords
+                cur_table = "`{}`".format(cur_table)
+
+            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
+            column_names = [description[0] for description in cursor.description]
+            values = cursor.fetchall()
+            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
+                num_rows, cur_table, num_rows, rows_prompt
+            )
+            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
+
+    for k, v in schemas.items():
+        full_schema_prompt_list.append(v)
+
+    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
+
+    return schema_prompt
+
+
+def generate_comment_prompt(question, knowledge=None):
+    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
+    question_prompt = "-- Question: {}".format(question)
+
+    result_prompt = knowledge_prompt + "\n\n" + question_prompt
+
+    return result_prompt
+
+
+def generate_combined_prompts_one(db_path, question, knowledge=None):
+    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
+    comment_prompt = generate_comment_prompt(question, knowledge)
+
+    combined_prompts = schema_prompt + "\n\n" + comment_prompt
+
+    return combined_prompts
+
+
+def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
+    try:
+        if model.startswith("meta-llama/"):
+            llm = ChatTogether(
+                model=model,
+                temperature=0,
+            )
+            answer = llm.invoke(prompt).content
+        else:
+            client = LlamaAPIClient()
+
+            response = client.chat.completions.create(
+                model=model,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0,
+            )
+            answer = response.completion_message.content.text
+
+        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+        matches = pattern.findall(answer)
+        if matches != []:
+            result = matches[0]
+        else:
+            result = answer
+
+        print(result)
+    except Exception as e:
+        result = "error:{}".format(e)
+        print(f"{result=}")
+    return result
+
+
+def huggingface_finetuned(api_key, model):
+    if api_key == "finetuned":
+        model_id = model
+        model = AutoPeftModelForCausalLM.from_pretrained(
+            model_id, device_map="auto", torch_dtype=torch.float16
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        tokenizer.padding_side = "right"  # to prevent warnings
+
+        if tokenizer.pad_token is None:
+            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+            model.resize_token_embeddings(len(tokenizer))
+
+    elif api_key == "huggingface":
+        model_id = model
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto",
+            # attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16,
+            quantization_config=bnb_config,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+    return pipe
+
+
+def collect_response_from_llama(
+    db_path_list, question_list, api_key, model, knowledge_list=None
+):
+    """
+    :param db_path: str
+    :param question_list: []
+    :return: dict of responses
+    """
+    responses_dict = {}
+    response_list = []
+
+    if api_key in ["huggingface", "finetuned"]:
+        pipe = huggingface_finetuned(api_key=api_key, model=model)
+
+    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
+        print(
+            "--------------------- processing question #{}---------------------".format(
+                i + 1
+            )
+        )
+        print("the question is: {}".format(question))
+
+        if knowledge_list:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
+            )
+        else:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question
+            )
+
+        if api_key in ["huggingface", "finetuned"]:
+            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
+        else:
+            plain_result = cloud_llama(
+                api_key=api_key,
+                model=model,
+                prompt=cur_prompt,
+                max_tokens=4096,
+                temperature=0,
+                stop=["--", "\n\n", ";", "#"],
+            )
+        if isinstance(plain_result, str):
+            sql = plain_result
+        else:
+            sql = "SELECT" + plain_result["choices"][0]["text"]
+
+        # responses_dict[i] = sql
+        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
+        sql = (
+            sql + "\t----- bird -----\t" + db_id
+        )  # to avoid unpredicted \t appearing in codex results
+        response_list.append(sql)
+
+    return response_list
+
+
+def question_package(data_json, knowledge=False):
+    question_list = []
+    for data in data_json:
+        question_list.append(data["question"])
+
+    return question_list
+
+
+def knowledge_package(data_json, knowledge=False):
+    knowledge_list = []
+    for data in data_json:
+        knowledge_list.append(data["evidence"])
+
+    return knowledge_list
+
+
+def decouple_question_schema(datasets, db_root_path):
+    question_list = []
+    db_path_list = []
+    knowledge_list = []
+    for i, data in enumerate(datasets):
+        question_list.append(data["question"])
+        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
+        db_path_list.append(cur_db_path)
+        knowledge_list.append(data["evidence"])
+
+    return question_list, db_path_list, knowledge_list
+
+
+def generate_sql_file(sql_lst, output_path=None):
+    result = {}
+    for i, sql in enumerate(sql_lst):
+        result[i] = sql
+
+    if output_path:
+        directory_path = os.path.dirname(output_path)
+        new_directory(directory_path)
+        json.dump(result, open(output_path, "w"), indent=4)
+
+    return result
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument("--eval_path", type=str, default="")
+    args_parser.add_argument("--mode", type=str, default="dev")
+    args_parser.add_argument("--test_path", type=str, default="")
+    args_parser.add_argument("--use_knowledge", type=str, default="True")
+    args_parser.add_argument("--db_root_path", type=str, default="")
+    args_parser.add_argument("--api_key", type=str, required=True)
+    args_parser.add_argument("--model", type=str, required=True)
+    args_parser.add_argument("--data_output_path", type=str)
+    args = args_parser.parse_args()
+
+    if args.api_key not in ["huggingface", "finetuned"]:
+        if args.model.startswith("meta-llama/"):  # Llama model on Together
+            os.environ["TOGETHER_API_KEY"] = args.api_key
+            llm = ChatTogether(
+                model=args.model,
+                temperature=0,
+            )
+            try:
+                response = llm.invoke("125*125 is?").content
+                print(f"{response=}")
+            except Exception as exception:
+                print(f"{exception=}")
+                exit(1)
+        else:  # Llama model on Llama API
+            os.environ["LLAMA_API_KEY"] = args.api_key
+
+            try:
+                client = LlamaAPIClient()
+
+                response = client.chat.completions.create(
+                    model=args.model,
+                    messages=[{"role": "user", "content": "125*125 is?"}],
+                    temperature=0,
+                )
+                answer = response.completion_message.content.text
+
+                print(f"{answer=}")
+            except Exception as exception:
+                print(f"{exception=}")
+                exit(1)
+
+    eval_data = json.load(open(args.eval_path, "r"))
+    # Uncomment for a quick debug run on the first 3 examples:
+    # eval_data = eval_data[:3]
+
+    question_list, db_path_list, knowledge_list = decouple_question_schema(
+        datasets=eval_data, db_root_path=args.db_root_path
+    )
+    assert len(question_list) == len(db_path_list) == len(knowledge_list)
+
+    if args.use_knowledge == "True":
+        responses = collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=knowledge_list,
+        )
+    else:
+        responses = collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=None,
+        )
+
+    output_name = args.data_output_path + "predict_" + args.mode + ".json"
+
+    generate_sql_file(sql_lst=responses, output_path=output_name)
+
+    print("successfully collect results from {}".format(args.model))

+ 19 - 0
end-to-end-use-cases/coding/text2sql/tool/requirements.txt

@@ -0,0 +1,19 @@
+llama_api_client
+langchain-together
+sqlparse
+torch==2.4.1
+tensorboard 
+liger-kernel==0.4.2
+setuptools
+deepspeed==0.15.4
+# openai
+# lm-eval[api]==0.4.5
+transformers==4.46.3
+datasets==3.6.0
+accelerate==1.1.1
+bitsandbytes==0.44.1
+trl==0.12.1
+peft==0.13.2
+lighteval==0.6.2
+hf-transfer==0.1.8
+func_timeout

+ 199 - 0
end-to-end-use-cases/coding/text2sql/tool/text2sql_eval.py

@@ -0,0 +1,199 @@
+import argparse
+import json
+import multiprocessing as mp
+import sqlite3
+import sys
+
+from func_timeout import func_timeout, FunctionTimedOut
+
+
+def load_json(path):
+    with open(path, "r") as j:
+        contents = json.loads(j.read())
+    return contents
+
+
+def result_callback(result):
+    # Pool callback; exec_result is the module-level list defined in __main__
+    exec_result.append(result)
+
+
+def execute_sql(predicted_sql, ground_truth, db_path):
+    # Connect to the database and run both queries
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+    cursor.execute(predicted_sql)
+    predicted_res = cursor.fetchall()
+    cursor.execute(ground_truth)
+    ground_truth_res = cursor.fetchall()
+    res = 0
+    if set(predicted_res) == set(ground_truth_res):
+        res = 1
+    else:
+        print(
+            f"\n\n==== INCORRECT SQL GENERATED ====\n{predicted_sql=}\n{predicted_res=}\n{ground_truth=}\n{ground_truth_res=}\n======\n\n"
+        )
+
+    return res
+
+
+def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
+    try:
+        res = func_timeout(
+            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
+        )
+    except KeyboardInterrupt:
+        sys.exit(0)
+    except FunctionTimedOut:
+        res = 0  # query timed out
+    except Exception:
+        res = 0  # not executable, e.g. malformed or overly long query
+    result = {"sql_idx": idx, "res": res}
+    return result
+
+
+def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
+    clean_sqls = []
+    db_path_list = []
+    if mode == "gpt":
+        sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r"))
+        for idx, sql_str in sql_data.items():
+            if isinstance(sql_str, str):
+                sql, db_name = sql_str.split("\t----- bird -----\t")
+            else:
+                sql, db_name = " ", "financial"
+            clean_sqls.append(sql)
+
+            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
+
+    elif mode == "gt":  # ground truth
+        items = json.load(open(db_root_path + "/../dev.json"))
+
+        for item in items:
+            sql = item["SQL"]
+            db_name = item["db_id"]
+            clean_sqls.append(sql)
+            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
+
+    return clean_sqls, db_path_list
+
+
+def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
+    pool = mp.Pool(processes=num_cpus)
+    for i, sql_pair in enumerate(sqls):
+
+        predicted_sql, ground_truth = sql_pair
+        pool.apply_async(
+            execute_model,
+            args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
+            callback=result_callback,
+        )
+    pool.close()
+    pool.join()
+
+
+def sort_results(list_of_dicts):
+    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
+
+
+def compute_acc_by_diff(exec_results, diff_json_path):
+    num_queries = len(exec_results)
+    results = [res["res"] for res in exec_results]
+    contents = load_json(diff_json_path)
+
+    simple_results, moderate_results, challenging_results = [], [], []
+
+    for i, content in enumerate(contents):
+        if content["difficulty"] == "simple":
+            simple_results.append(exec_results[i])
+
+        if content["difficulty"] == "moderate":
+            moderate_results.append(exec_results[i])
+
+        if content["difficulty"] == "challenging":
+            challenging_results.append(exec_results[i])
+
+    def bucket_acc(bucket):
+        # Guard against empty difficulty buckets to avoid ZeroDivisionError
+        return 0 if len(bucket) == 0 else sum([res["res"] for res in bucket]) / len(bucket)
+
+    simple_acc = bucket_acc(simple_results)
+    moderate_acc = bucket_acc(moderate_results)
+    challenging_acc = bucket_acc(challenging_results)
+    all_acc = sum(results) / num_queries
+    count_lists = [
+        len(simple_results),
+        len(moderate_results),
+        len(challenging_results),
+        num_queries,
+    ]
+    return (
+        simple_acc * 100,
+        moderate_acc * 100,
+        challenging_acc * 100,
+        all_acc * 100,
+        count_lists,
+    )
+
+
+def print_data(score_lists, count_lists):
+    levels = ["simple", "moderate", "challenging", "total"]
+    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
+
+    print(
+        "======================================    ACCURACY    ====================================="
+    )
+    print(
+        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
+    )
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument(
+        "--predicted_sql_path", type=str, required=True, default=""
+    )
+    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
+    args_parser.add_argument("--data_mode", type=str, default="dev")
+    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
+    args_parser.add_argument("--num_cpus", type=int, default=1)
+    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
+    args_parser.add_argument("--mode_gt", type=str, default="gt")
+    args_parser.add_argument("--mode_predict", type=str, default="gpt")
+    args_parser.add_argument("--difficulty", type=str, default="simple")
+    args_parser.add_argument("--diff_json_path", type=str, default="")
+    args = args_parser.parse_args()
+    exec_result = []
+
+    pred_queries, db_paths = package_sqls(
+        args.predicted_sql_path,
+        args.db_root_path,
+        mode=args.mode_predict,
+        data_mode=args.data_mode,
+    )
+    # generate gt sqls:
+    gt_queries, db_paths_gt = package_sqls(
+        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
+    )
+
+    query_pairs = list(zip(pred_queries, gt_queries))
+    run_sqls_parallel(
+        query_pairs,
+        db_places=db_paths,
+        num_cpus=args.num_cpus,
+        meta_time_out=args.meta_time_out,
+    )
+    exec_result = sort_results(exec_result)
+
+    print("Evaluating statistics...")
+    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
+        exec_result, args.diff_json_path
+    )
+    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+    print_data(score_lists, count_lists)
+    print(
+        "==========================================================================================="
+    )
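+
+# Example standalone invocation, mirroring the call in llama_eval.sh:
+# python3 -u text2sql_eval.py --db_root_path ./data/dev_20240627/dev_databases/ \
+#     --predicted_sql_path ./output/<model>/ --ground_truth_path ./data/ \
+#     --diff_json_path ./data/dev_20240627/dev.json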