
adding vllm eval files and updating requirements.txt

Amir Youssefi 3 days ago
commit 2bd662c67d

+ 31 - 0
end-to-end-use-cases/coding/text2sql/eval/llama_eval_vllm.sh

@@ -0,0 +1,31 @@
+#!/bin/bash
+eval_path='../data/dev_20240627/dev.json'
+db_root_path='../data/dev_20240627/dev_databases/'
+ground_truth_path='../data/'
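+# NOTE: paths assume the BIRD dev set (dev_20240627) is unpacked under ../data/;
+# adjust eval_path, db_root_path, and ground_truth_path if yours lives elsewhere.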
+
+model='meta-llama/Llama-3.1-8B-Instruct'
+data_output_path="./output/$model/"
+
+echo "Text2SQL using $model"
+python3 -u llama_text2sql_vllm.py --db_root_path ${db_root_path} \
+--model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}
+
+# Capture the exit code immediately: $? is overwritten by every subsequent
+# command, so reading it again inside the else branch would always report 0.
+exit_code=$?
+
+# Check if llama_text2sql_vllm.py exited successfully
+if [ $exit_code -eq 0 ]; then
+    echo "llama_text2sql_vllm.py completed successfully. Proceeding with evaluation..."
+    python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
+    --ground_truth_path ${ground_truth_path} \
+    --diff_json_path ${eval_path}
+
+    echo "Done evaluating $model."
+
+else
+    echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation."
+    exit 1
+fi

+ 308 - 0
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql_vllm.py

@@ -0,0 +1,308 @@
+import argparse
+import json
+import os
+import re
+import sqlite3
+from typing import Dict
+
+from tqdm import tqdm
+
+from vllm import LLM
+
+DEFAULT_MAX_TOKENS = 10240
+SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
+# UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
+# SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
+
+
+def inference(llm, sampling_params, user_prompt):
+    messages = [
+        {"content": SYSTEM_PROMPT, "role": "system"},
+        {"role": "user", "content": user_prompt},
+    ]
+
+    print(f"{messages=}")
+
+    response = llm.chat(messages, sampling_params, use_tqdm=False)
+    print(f"{response=}")
+    response_text = response[0].outputs[0].text
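+    # The model often wraps the SQL in a markdown ```sql fence; take the first
+    # fenced block when present, otherwise fall back to the raw response text.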
+    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+    matches = pattern.findall(response_text)
+    if matches:
+        result = matches[0]
+    else:
+        result = response_text
+    print(f"{result=}")
+    return result
+
+
+def new_directory(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
+    """
+    Read an sqlite file, and return the CREATE commands for each of the tables in the database.
+    """
+    asdf = "database" if bench_root == "spider" else "databases"
+    with sqlite3.connect(
+        f"file:{bench_root}/{asdf}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
+    ) as conn:
+        # conn.text_factory = bytes
+        cursor = conn.cursor()
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+        tables = cursor.fetchall()
+        schemas = {}
+        for table in tables:
+            cursor.execute(
+                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
+                    table[0]
+                )
+            )
+            schemas[table[0]] = cursor.fetchone()[0]
+
+        return schemas
+
+
+def nice_look_table(column_names: list, values: list):
+    rows = []
+    # Determine the maximum width of each column
+    widths = [
+        max(len(str(value[i])) for value in values + [column_names])
+        for i in range(len(column_names))
+    ]
+
+    # Print the column names
+    header = "".join(
+        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
+    )
+    # print(header)
+    # Print the values
+    for value in values:
+        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
+        rows.append(row)
+    rows = "\n".join(rows)
+    final_output = header + "\n" + rows
+    return final_output
+
+
+def generate_schema_prompt(db_path, num_rows=None):
+    """
+    Build a schema prompt from the CREATE TABLE DDLs in a SQLite database.
+
+    :param db_path: path to the .sqlite file
+    :param num_rows: if set, append num_rows example rows per table
+    """
+    full_schema_prompt_list = []
+    conn = sqlite3.connect(db_path)
+    # Create a cursor object
+    cursor = conn.cursor()
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = cursor.fetchall()
+    schemas = {}
+    for table in tables:
+        if table == "sqlite_sequence":
+            continue
+        cursor.execute(
+            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
+                table[0]
+            )
+        )
+        create_prompt = cursor.fetchone()[0]
+        schemas[table[0]] = create_prompt
+        if num_rows:
+            cur_table = table[0]
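+            # Table names that collide with SQL keywords (e.g. "order") must be
+            # backtick-quoted before being interpolated into the SELECT below.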
+            if cur_table in ["order", "by", "group"]:
+                cur_table = "`{}`".format(cur_table)
+
+            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
+            column_names = [description[0] for description in cursor.description]
+            values = cursor.fetchall()
+            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
+                num_rows, cur_table, num_rows, rows_prompt
+            )
+            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
+
+    for k, v in schemas.items():
+        full_schema_prompt_list.append(v)
+
+    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
+
+    return schema_prompt
+
+
+def generate_comment_prompt(question, knowledge=None):
+    # Skip the knowledge section when no evidence is provided; otherwise the
+    # prompt would contain a literal "-- External Knowledge: None".
+    question_prompt = "-- Question: {}".format(question)
+    if knowledge is None:
+        return question_prompt
+
+    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
+    return knowledge_prompt + "\n\n" + question_prompt
+
+
+def generate_combined_prompts_one(db_path, question, knowledge=None):
+    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
+    comment_prompt = generate_comment_prompt(question, knowledge)
+
+    combined_prompts = schema_prompt + "\n\n" + comment_prompt
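+    # Resulting prompt layout (sketch):
+    #   -- DB Schema: <CREATE TABLE statements, plus example rows if requested>
+    #   -- External Knowledge: <evidence, when provided>
+    #   -- Question: <question>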
+
+    return combined_prompts
+
+
+def collect_response_from_llama(
+    llm, sampling_params, db_path_list, question_list, knowledge_list=None
+):
+    response_list = []
+
+    for i, question in tqdm(enumerate(question_list), total=len(question_list)):
+        print(
+            "--------------------- processing question #{}---------------------".format(
+                i + 1
+            )
+        )
+        print("the question is: {}".format(question))
+
+        if knowledge_list:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
+            )
+        else:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question
+            )
+
+        plain_result = inference(llm, sampling_params, cur_prompt)
+        if isinstance(plain_result, str):
+            sql = plain_result
+        else:
+            # fallback for completion-style dict responses from other backends
+            sql = "SELECT" + plain_result["choices"][0]["text"]
+
+        # responses_dict[i] = sql
+        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
+        sql = (
+            sql + "\t----- bird -----\t" + db_id
+        )  # to avoid unpredicted \t appearing in codex results
+        response_list.append(sql)
+
+    return response_list
+
+
+def question_package(data_json, knowledge=False):
+    question_list = []
+    for data in data_json:
+        question_list.append(data["question"])
+
+    return question_list
+
+
+def knowledge_package(data_json, knowledge=False):
+    knowledge_list = []
+    for data in data_json:
+        knowledge_list.append(data["evidence"])
+
+    return knowledge_list
+
+
+def decouple_question_schema(datasets, db_root_path):
+    question_list = []
+    db_path_list = []
+    knowledge_list = []
+    for i, data in enumerate(datasets):
+        question_list.append(data["question"])
+        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
+        db_path_list.append(cur_db_path)
+        knowledge_list.append(data["evidence"])
+
+    return question_list, db_path_list, knowledge_list
+
+
+def generate_sql_file(sql_lst, output_path=None):
+    result = {}
+    for i, sql in enumerate(sql_lst):
+        result[i] = sql
+
+    if output_path:
+        directory_path = os.path.dirname(output_path)
+        new_directory(directory_path)
+        with open(output_path, "w") as f:
+            json.dump(result, f, indent=4)
+
+    return result
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument("--eval_path", type=str, default="")
+    args_parser.add_argument("--mode", type=str, default="dev")
+    args_parser.add_argument("--test_path", type=str, default="")
+    args_parser.add_argument("--use_knowledge", type=str, default="True")
+    args_parser.add_argument("--db_root_path", type=str, default="")
+    args_parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct")
+    args_parser.add_argument("--data_output_path", type=str)
+    args_parser.add_argument("--max_tokens", type=int, default=DEFAULT_MAX_TOKENS)
+    args_parser.add_argument("--temperature", type=float, default=0.0)
+    args_parser.add_argument("--top_k", type=int, default=50)
+    args_parser.add_argument("--top_p", type=float, default=0.1)
+    args = args_parser.parse_args()
+
+    with open(args.eval_path, "r") as f:
+        eval_data = json.load(f)
+    # '''for debug'''
+    # eval_data = eval_data[:3]
+    # '''for debug'''
+
+    question_list, db_path_list, knowledge_list = decouple_question_schema(
+        datasets=eval_data, db_root_path=args.db_root_path
+    )
+    assert len(question_list) == len(db_path_list) == len(knowledge_list)
+
+    # Download to the default Hugging Face cache; pass download_dir= to LLM()
+    # if the weights should be stored elsewhere.
+    llm = LLM(model=args.model)
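+    # get_default_sampling_params() seeds defaults from the model's
+    # generation_config.json when available; the CLI flags override them below.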
+    sampling_params = llm.get_default_sampling_params()
+    sampling_params.max_tokens = args.max_tokens
+    sampling_params.temperature = args.temperature
+    sampling_params.top_p = args.top_p
+    sampling_params.top_k = args.top_k
+
+    if args.use_knowledge == "True":
+        responses = collect_response_from_llama(
+            llm=llm,
+            sampling_params=sampling_params,
+            db_path_list=db_path_list,
+            question_list=question_list,
+            knowledge_list=knowledge_list,
+        )
+    else:
+        responses = collect_response_from_llama(
+            llm=llm,
+            sampling_params=sampling_params,
+            db_path_list=db_path_list,
+            question_list=question_list,
+            knowledge_list=None,
+        )
+
+    output_name = os.path.join(args.data_output_path, "predict_" + args.mode + ".json")
+
+    generate_sql_file(sql_lst=responses, output_path=output_name)
+
+    print("successfully collect results from {}".format(args.model))

+ 3 - 0
end-to-end-use-cases/coding/text2sql/eval/requirements.txt

@@ -15,3 +15,6 @@ peft==0.13.2
 lighteval==0.6.2
 hf-transfer==0.1.8
 func_timeout==4.3.5
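+# local inference backend for llama_text2sql_vllm.py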
+vllm==0.9.2
+flashinfer-python==0.2.7.post1