
Add script to save the dev set in pandas CSV format
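
Builds one prompt per BIRD dev example (DB schema CREATE statements, external knowledge, and the question), pairs it with the gold SQL and db_id, and writes the rows to bird_dev_set_eval.csv. The llama_eval.sh driver is removed in the same commit.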

Jeff Tang 2 months ago
parent
commit
af3ea4fe64

+ 159 - 0
end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py

@@ -0,0 +1,159 @@
+import argparse
+import json
+import os
+import sqlite3
+
+import pandas as pd
+
+# from datasets import Dataset
+from tqdm import tqdm
+
+
+def new_directory(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def nice_look_table(column_names: list, values: list):
+    rows = []
+    # Determine the maximum width of each column
+    widths = [
+        max(len(str(value[i])) for value in values + [column_names])
+        for i in range(len(column_names))
+    ]
+
+    # Print the column names
+    header = "".join(
+        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
+    )
+    # print(header)
+    # Print the values
+    for value in values:
+        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
+        rows.append(row)
+    rows = "\n".join(rows)
+    final_output = header + "\n" + rows
+    return final_output
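+
+
+# Illustrative output of nice_look_table (each value right-justified to the widest
+# entry in its column), e.g. for columns ["id", "name"] and rows
+# [(1, "Alice"), (2, "Bob")]:
+#   id  name
+#    1 Alice
+#    2   Bob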
+
+
+def generate_schema_prompt(db_path, num_rows=None):
+    # extract create ddls
+    """
+    :param root_place:
+    :param db_name:
+    :return:
+    """
+    full_schema_prompt_list = []
+    conn = sqlite3.connect(db_path)
+    # Create a cursor object
+    cursor = conn.cursor()
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = cursor.fetchall()
+    schemas = {}
+    for table in tables:
+        # fetchall() returns one-element tuples, so compare the table name itself
+        if table[0] == "sqlite_sequence":
+            continue
+        cursor.execute(
+            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
+            (table[0],),
+        )
+        create_prompt = cursor.fetchone()[0]
+        schemas[table[0]] = create_prompt
+        if num_rows:
+            cur_table = table[0]
+            if cur_table in ["order", "by", "group"]:
+                cur_table = "`{}`".format(cur_table)
+
+            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
+            column_names = [description[0] for description in cursor.description]
+            values = cursor.fetchall()
+            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
+                num_rows, cur_table, num_rows, rows_prompt
+            )
+            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
+
+    for k, v in schemas.items():
+        full_schema_prompt_list.append(v)
+
+    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
+
+    return schema_prompt
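+
+
+# Illustrative shape of the schema prompt returned above (table names are
+# placeholders; with num_rows set, each CREATE TABLE is followed by a commented
+# block of example rows):
+#   -- DB Schema: CREATE TABLE table_1 (...)
+#
+#   CREATE TABLE table_2 (...)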
+
+
+def generate_comment_prompt(question, knowledge=None):
+    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
+    question_prompt = "-- Question: {}".format(question)
+
+    result_prompt = knowledge_prompt + "\n\n" + question_prompt
+
+    return result_prompt
+
+
+def generate_combined_prompts_one(db_path, question, knowledge=None):
+    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
+    comment_prompt = generate_comment_prompt(question, knowledge)
+
+    combined_prompts = schema_prompt + "\n\n" + comment_prompt
+
+    return combined_prompts
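+
+
+# The combined prompt therefore consists of three blocks separated by blank lines:
+#   -- DB Schema: <CREATE TABLE statements>
+#   -- External Knowledge: <evidence text, or None>
+#   -- Question: <natural-language question>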
+
+
+def create_conversation(sample):
+    return {
+        "messages": [
+            {"role": "system", "content": sample["messages"][0]["content"]},
+            {"role": "user", "content": sample["messages"][1]["content"]},
+            {"role": "assistant", "content": sample["messages"][2]["content"]},
+        ]
+    }
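+
+
+# Note: create_conversation builds a chat-format sample (system/user/assistant
+# messages); it is not called anywhere in this script when writing the eval CSV.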
+
+
+def create_bird_eval_dataset(input_json, db_root_path):
+    SYSTEM_PROMPT = (
+        "You are a text to SQL query translator. Using the SQLite DB Schema and the "
+        "External Knowledge, translate the following text question into a SQLite SQL "
+        "select statement."
+    )
+    data = []
+
+    for i, item in tqdm(enumerate(input_json), total=len(input_json)):
+        print(f"processing #{i+1}")
+        db_id = item["db_id"]
+        question = item["question"]
+        external_knowledge = item["evidence"]
+        SQL = item["SQL"]
+        db_path = os.path.join(db_root_path, db_id, db_id + ".sqlite")
+        print(f"{db_path=}")
+        prompt = generate_combined_prompts_one(
+            db_path,
+            question,
+            knowledge=external_knowledge,
+        )
+
+        data.append(
+            {
+                "prompt": SYSTEM_PROMPT + "\n\n" + prompt,
+                "gold_sql": SQL,
+                "db_id": db_id,
+            }
+        )
+
+    df = pd.DataFrame(data)
+    df.to_csv("bird_dev_set_eval.csv", index=False)
+    print(f"Dataset saved as bird_dev_set_eval.csv with {len(df)} rows")
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument("--input_json", type=str, required=True)
+    args_parser.add_argument("--db_root_path", type=str, required=True)
+    args = args_parser.parse_args()
+
+    with open(args.input_json, "r") as f:
+        input_json = json.load(f)
+    db_root_path = args.db_root_path
+
+    create_bird_eval_dataset(input_json, db_root_path)
+
+# python3 create_bird_eval_dataset.py --input_json ../data/dev_20240627/dev.json --db_root_path ../data/dev_20240627/dev_databases
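+
+# A minimal sketch (not part of this script) of how the saved CSV could be consumed
+# in a downstream eval loop; generate_sql and compare_results are hypothetical
+# placeholders for the model client and the scoring step:
+#
+#   df = pd.read_csv("bird_dev_set_eval.csv")
+#   for _, row in df.iterrows():
+#       predicted_sql = generate_sql(row["prompt"])  # hypothetical model call
+#       compare_results(predicted_sql, row["gold_sql"], row["db_id"])  # hypothetical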

+ 0 - 50
end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh

@@ -1,50 +0,0 @@
-# Set to "true" to enable debug mode with detailed prints
-DEBUG_MODE="false"
-
-eval_path='../data/dev_20240627/dev.json'
-db_root_path='../data/dev_20240627/dev_databases/'
-ground_truth_path='../data/'
-
-# Llama models on Llama API
-YOUR_API_KEY='YOUR_LLAMA_API_KEY'
-model='Llama-3.3-8B-Instruct'
-#model='Llama-3.3-70B-Instruct'
-#model='Llama-4-Maverick-17B-128E-Instruct-FP8'
-#model='Llama-4-Scout-17B-16E-Instruct-FP8'
-
-# Llama model on Hugging Face Hub https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
-# YOUR_API_KEY='huggingface'
-# model='meta-llama/Llama-3.1-8B-Instruct'
-
-# Fine-tuned Llama models locally
-# YOUR_API_KEY='finetuned'
-# model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3'
-
-data_output_path="./output/$model/"
-
-echo "Text2SQL using $model"
-
-python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \
---model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}
-
-# Check if llama_text2sql.py exited successfully
-if [ $? -eq 0 ]; then
-    echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."
-
-    # Add --debug flag if DEBUG_MODE is true
-    if [ "$DEBUG_MODE" = "true" ]; then
-        python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
-        --ground_truth_path ${ground_truth_path} \
-        --diff_json_path ${eval_path} --debug
-    else
-        python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
-        --ground_truth_path ${ground_truth_path} \
-        --diff_json_path ${eval_path}
-    fi
-
-    echo "Done evaluating $model."
-
-else
-    echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation."
-    exit 1
-fi