浏览代码

FT readme update; removed old vllm py and sh files

Jeff Tang 2 天之前
父节点
当前提交
6501cf4566

+ 0 - 24
end-to-end-use-cases/coding/text2sql/eval/llama_eval_vllm.sh

@@ -1,24 +0,0 @@
-eval_path='../data/dev_20240627/dev.json'
-db_root_path='../data/dev_20240627/dev_databases/'
-ground_truth_path='../data/'
-
-model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
-data_output_path="./output/$model/"
-
-echo "Text2SQL using $model"
-python3 -u llama_text2sql_vllm.py --db_root_path ${db_root_path} \
---model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}
-
-# Check if llama_text2sql.py exited successfully
-if [ $? -eq 0 ]; then
-    echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."
-    python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
-    --ground_truth_path ${ground_truth_path} \
-    --diff_json_path ${eval_path}
-
-    echo "Done evaluating $model."
-
-else
-    echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation."
-    exit 1
-fi

+ 0 - 289
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql_vllm.py

@@ -1,289 +0,0 @@
-import argparse
-import json
-import os
-import re
-import sqlite3
-from typing import Dict, List, Tuple
-
-from tqdm import tqdm
-
-from vllm import LLM, EngineArgs, SamplingParams
-
-DEFAULT_MAX_TOKENS=10240
-SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
-# UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
-# SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
-
-
-def inference(llm, sampling_params, user_prompt):
-    messages = [
-        {"content": SYSTEM_PROMPT, "role": "system"},
-        {"role": "user", "content": user_prompt},
-    ]
-
-    print(f"{messages=}")
-
-    response = llm.chat(messages, sampling_params, use_tqdm=False)
-    print(f"{response=}")
-    response_text = response[0].outputs[0].text
-    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
-    matches = pattern.findall(response_text)
-    if matches != []:
-        result = matches[0]
-    else:
-        result = response_text
-    print(f"{result=}")
-    return result
-
-
-def new_directory(path):
-    if not os.path.exists(path):
-        os.makedirs(path)
-
-
-def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
-    """
-    Read an sqlite file, and return the CREATE commands for each of the tables in the database.
-    """
-    asdf = "database" if bench_root == "spider" else "databases"
-    with sqlite3.connect(
-        f"file:{bench_root}/{asdf}/{db_name}/{db_name}.sqlite?mode=ro", uri=True
-    ) as conn:
-        # conn.text_factory = bytes
-        cursor = conn.cursor()
-        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
-        tables = cursor.fetchall()
-        schemas = {}
-        for table in tables:
-            cursor.execute(
-                "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
-                    table[0]
-                )
-            )
-            schemas[table[0]] = cursor.fetchone()[0]
-
-        return schemas
-
-
-def nice_look_table(column_names: list, values: list):
-    rows = []
-    # Determine the maximum width of each column
-    widths = [
-        max(len(str(value[i])) for value in values + [column_names])
-        for i in range(len(column_names))
-    ]
-
-    # Print the column names
-    header = "".join(
-        f"{column.rjust(width)} " for column, width in zip(column_names, widths)
-    )
-    # print(header)
-    # Print the values
-    for value in values:
-        row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
-        rows.append(row)
-    rows = "\n".join(rows)
-    final_output = header + "\n" + rows
-    return final_output
-
-
-def generate_schema_prompt(db_path, num_rows=None):
-    # extract create ddls
-    """
-    :param root_place:
-    :param db_name:
-    :return:
-    """
-    full_schema_prompt_list = []
-    conn = sqlite3.connect(db_path)
-    # Create a cursor object
-    cursor = conn.cursor()
-    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
-    tables = cursor.fetchall()
-    schemas = {}
-    for table in tables:
-        if table == "sqlite_sequence":
-            continue
-        cursor.execute(
-            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
-                table[0]
-            )
-        )
-        create_prompt = cursor.fetchone()[0]
-        schemas[table[0]] = create_prompt
-        if num_rows:
-            cur_table = table[0]
-            if cur_table in ["order", "by", "group"]:
-                cur_table = "`{}`".format(cur_table)
-
-            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
-            column_names = [description[0] for description in cursor.description]
-            values = cursor.fetchall()
-            rows_prompt = nice_look_table(column_names=column_names, values=values)
-            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
-                num_rows, cur_table, num_rows, rows_prompt
-            )
-            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
-
-    for k, v in schemas.items():
-        full_schema_prompt_list.append(v)
-
-    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
-
-    return schema_prompt
-
-
-def generate_comment_prompt(question, knowledge=None):
-    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
-    question_prompt = "-- Question: {}".format(question)
-
-    result_prompt = knowledge_prompt + "\n\n" + question_prompt
-
-    return result_prompt
-
-
-def generate_combined_prompts_one(db_path, question, knowledge=None):
-    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
-    comment_prompt = generate_comment_prompt(question, knowledge)
-
-    combined_prompts = schema_prompt + "\n\n" + comment_prompt
-
-    return combined_prompts
-
-
-
-def collect_response_from_llama(
-    llm, sampling_params, db_path_list, question_list, knowledge_list=None
-):
-    response_list = []
-
-    for i, question in tqdm(enumerate(question_list)):
-        print(
-            "--------------------- processing question #{}---------------------".format(
-                i + 1
-            )
-        )
-        print("the question is: {}".format(question))
-
-        if knowledge_list:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
-            )
-        else:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question
-            )
-
-        plain_result = inference(llm, sampling_params, cur_prompt)
-        if type(plain_result) == str:
-            sql = plain_result
-        else:
-            sql = "SELECT" + plain_result["choices"][0]["text"]
-
-        # responses_dict[i] = sql
-        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
-        sql = (
-            sql + "\t----- bird -----\t" + db_id
-        )  # to avoid unpredicted \t appearing in codex results
-        response_list.append(sql)
-
-    return response_list
-
-
-def question_package(data_json, knowledge=False):
-    question_list = []
-    for data in data_json:
-        question_list.append(data["question"])
-
-    return question_list
-
-
-def knowledge_package(data_json, knowledge=False):
-    knowledge_list = []
-    for data in data_json:
-        knowledge_list.append(data["evidence"])
-
-    return knowledge_list
-
-
-def decouple_question_schema(datasets, db_root_path):
-    question_list = []
-    db_path_list = []
-    knowledge_list = []
-    for i, data in enumerate(datasets):
-        question_list.append(data["question"])
-        cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite"
-        db_path_list.append(cur_db_path)
-        knowledge_list.append(data["evidence"])
-
-    return question_list, db_path_list, knowledge_list
-
-
-def generate_sql_file(sql_lst, output_path=None):
-    result = {}
-    for i, sql in enumerate(sql_lst):
-        result[i] = sql
-
-    if output_path:
-        directory_path = os.path.dirname(output_path)
-        new_directory(directory_path)
-        json.dump(result, open(output_path, "w"), indent=4)
-
-    return result
-
-
-if __name__ == "__main__":
-    args_parser = argparse.ArgumentParser()
-    args_parser.add_argument("--eval_path", type=str, default="")
-    args_parser.add_argument("--mode", type=str, default="dev")
-    args_parser.add_argument("--test_path", type=str, default="")
-    args_parser.add_argument("--use_knowledge", type=str, default="True")
-    args_parser.add_argument("--db_root_path", type=str, default="")
-    args_parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct")
-    args_parser.add_argument("--data_output_path", type=str)
-    args_parser.add_argument("--max_tokens", type=int, default=DEFAULT_MAX_TOKENS)
-    args_parser.add_argument("--temperature", type=float, default=0.0)
-    args_parser.add_argument("--top_k", type=int, default=50)
-    args_parser.add_argument("--top_p", type=float, default=0.1)
-    args = args_parser.parse_args()
-
-    eval_data = json.load(open(args.eval_path, "r"))
-    # '''for debug'''
-    # eval_data = eval_data[:3]
-    # '''for debug'''
-
-    question_list, db_path_list, knowledge_list = decouple_question_schema(
-        datasets=eval_data, db_root_path=args.db_root_path
-    )
-    assert len(question_list) == len(db_path_list) == len(knowledge_list)
-
-    llm = LLM(model=args.model, download_dir="/opt/hpcaas/.mounts/fs-06ad2f76a5ad0b18f/shared/amiryo/.cache/vllm")
-    sampling_params = llm.get_default_sampling_params()
-    sampling_params.max_tokens = args.max_tokens
-    sampling_params.temperature = args.temperature
-    sampling_params.top_p = args.top_p
-    sampling_params.top_k = args.top_k
-
-
-    if args.use_knowledge == "True":
-        responses = collect_response_from_llama(
-            llm=llm,
-            sampling_params=sampling_params,
-            db_path_list=db_path_list,
-            question_list=question_list,
-            knowledge_list=knowledge_list,
-        )
-    else:
-        responses = collect_response_from_llama(
-            llm=llm,
-            sampling_params=sampling_params,
-            db_path_list=db_path_list,
-            question_list=question_list,
-            knowledge_list=None,
-        )
-
-    output_name = args.data_output_path + "predict_" + args.mode + ".json"
-
-    generate_sql_file(sql_lst=responses, output_path=output_name)
-
-    print("successfully collect results from {}".format(args.model))

+ 3 - 5
end-to-end-use-cases/coding/text2sql/fine-tuning/README.md

@@ -101,18 +101,16 @@ YOUR_API_KEY='finetuned'
 model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot'
 ```
 
-2. Uncomment the line `SYSTEM_PROMPT` [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py#L17) in `llama_text2sql.py` to use it with the reasoning dataset fine-tuned model.
-
-3. Start the vllm server by running
+2. Start the vllm server by running
 ```
 vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64
 ```
-If you have multiple GPUs you can run something like 
+If you have multiple GPUs you can run something like
 ```
 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64
 ```
  to speed up the eval.
 
-4. Run `sh llama_eval.sh`.
+3. Run `sh llama_eval.sh`.
 
 **Note:** If your fine-tuned model is PEFT based, you may need to run `python merge_peft.py` after modifying its `peft_model_path` and `output_dir` and set the merged folder path after `vllm serve`.