
Batch processing and parallel vLLM Llama calls; cleaner progress display in the 2-step eval

Jeff Tang committed 4 days ago
commit 77d3544c81

+ 18 - 7
end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh

@@ -1,10 +1,13 @@
+# Set to "true" to enable debug mode with detailed prints
+DEBUG_MODE="false"
+
 eval_path='../data/dev_20240627/dev.json'
 db_root_path='../data/dev_20240627/dev_databases/'
 ground_truth_path='../data/'
 
 # Llama models on Llama API
-YOUR_API_KEY='YOUR_LLAMA_API_KEY'
-model='Llama-3.3-8B-Instruct'
+# YOUR_API_KEY='YOUR_LLAMA_API_KEY'
+# model='Llama-3.3-8B-Instruct'
 #model='Llama-3.3-70B-Instruct'
 #model='Llama-4-Maverick-17B-128E-Instruct-FP8'
 #model='Llama-4-Scout-17B-16E-Instruct-FP8'
@@ -14,8 +17,8 @@ model='Llama-3.3-8B-Instruct'
 # model='meta-llama/Llama-3.1-8B-Instruct'
 
 # Fine-tuned Llama models locally
-# YOUR_API_KEY='finetuned'
-# model='../fine-tuning/final_test/llama31-8b-text2sql-peft-quantized-cot_merged'
+YOUR_API_KEY='finetuned'
+model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3'
 
 data_output_path="./output/$model/"
 
@@ -26,9 +29,17 @@ python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API
 # Check if llama_text2sql.py exited successfully
 if [ $? -eq 0 ]; then
     echo "llama_text2sql.py completed successfully. Proceeding with evaluation..."
-    python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
-    --ground_truth_path ${ground_truth_path} \
-    --diff_json_path ${eval_path}
+
+    # Add --debug flag if DEBUG_MODE is true
+    if [ "$DEBUG_MODE" = "true" ]; then
+        python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
+        --ground_truth_path ${ground_truth_path} \
+        --diff_json_path ${eval_path} --debug
+    else
+        python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \
+        --ground_truth_path ${ground_truth_path} \
+        --diff_json_path ${eval_path}
+    fi
 
     echo "Done evaluating $model."
 

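The script above now points the two-step eval at a locally served fine-tuned model, so it assumes the OpenAI-compatible server that llama_text2sql.py targets (http://localhost:8000/v1, see the next file) is already running. A quick sanity check before kicking off the eval might look like the sketch below; it is not part of this commit and assumes a vLLM-style server exposing the standard /v1/models route:

    # sanity_check_server.py -- illustrative only, not part of this commit
    from openai import OpenAI

    # Same endpoint and placeholder key that llama_text2sql.py uses when
    # --api_key is "huggingface" or "finetuned" (assumed to be a local vLLM
    # OpenAI-compatible server).
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    # The eval's $model value must match one of the model IDs the server serves.
    for m in client.models.list().data:
        print(m.id)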
+ 216 - 20
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -1,4 +1,5 @@
 import argparse
+import concurrent.futures
 import json
 import os
 import re
@@ -6,6 +7,7 @@ import sqlite3
 from typing import Dict
 
 from llama_api_client import LlamaAPIClient
+from tqdm import tqdm
 
 MAX_NEW_TOKENS = 10240  # If API has max tokens (vs max new tokens), we calculate it
 TIMEOUT = 60  # Timeout in seconds for each API call
@@ -25,20 +27,98 @@ def local_llama(client, prompt, model):
         model=model,
         messages=messages,
         timeout=TIMEOUT,
+        temperature=0,
     )
     answer = chat_response.choices[0].message.content.strip()
 
     pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
     matches = pattern.findall(answer)
-    if matches != []:
-        result = matches[0]
-    else:
+    if not matches:
         result = answer
+    else:
+        result = matches[0]
 
     print(f"{result=}")
     return result
 
 
+def batch_local_llama(client, prompts, model, max_workers=8):
+    """
+    Process multiple prompts in parallel using the vllm server.
+
+    Args:
+        client: OpenAI client
+        prompts: List of prompts to process
+        model: Model name
+        max_workers: Maximum number of parallel workers
+
+    Returns:
+        List of results in the same order as prompts
+    """
+    SYSTEM_PROMPT = (
+        "You are a text to SQL query translator. Using the SQLite DB Schema "
+        "and the External Knowledge, translate the following text question "
+        "into a SQLite SQL select statement."
+    )
+
+    def process_single_prompt(prompt):
+        messages = [
+            {"content": SYSTEM_PROMPT, "role": "system"},
+            {"role": "user", "content": prompt},
+        ]
+        try:
+            chat_response = client.chat.completions.create(
+                model=model,
+                messages=messages,
+                timeout=TIMEOUT,
+                temperature=0,
+            )
+            answer = chat_response.choices[0].message.content.strip()
+
+            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+            matches = pattern.findall(answer)
+            if not matches:
+                result = answer
+            else:
+                result = matches[0]
+
+            return result
+        except Exception as e:
+            print(f"Error processing prompt: {e}")
+            return f"error:{e}"
+
+    print(
+        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
+        f"using {max_workers} workers"
+    )
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks and create a map of futures to their indices
+        future_to_index = {
+            executor.submit(process_single_prompt, prompt): i
+            for i, prompt in enumerate(prompts)
+        }
+
+        # Initialize results list with None values
+        results = [None] * len(prompts)
+
+        # Process completed futures as they complete
+        for future in tqdm(
+            concurrent.futures.as_completed(future_to_index),
+            total=len(prompts),
+            desc="Processing prompts",
+        ):
+            index = future_to_index[future]
+            try:
+                results[index] = future.result()
+            except Exception as e:
+                print(f"Error processing prompt at index {index}: {e}")
+                results[index] = f"error:{e}"
+
+    return results
+
+
 def new_directory(path):
     if not os.path.exists(path):
         os.makedirs(path)
@@ -235,7 +315,7 @@ def collect_response_from_llama(
                 temperature=0,
                 stop=["--", "\n\n", ";", "#"],
             )
-        if type(plain_result) == str:
+        if isinstance(plain_result, str):
             sql = plain_result
         else:
             sql = "SELECT" + plain_result["choices"][0]["text"]
@@ -250,6 +330,89 @@ def collect_response_from_llama(
     return response_list
 
 
+def batch_collect_response_from_llama(
+    db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8
+):
+    """
+    Process multiple questions in parallel using the vllm server.
+
+    Args:
+        db_path_list: List of database paths
+        question_list: List of questions
+        api_key: API key
+        model: Model name
+        knowledge_list: List of knowledge strings (optional)
+        batch_size: Number of parallel requests
+
+    Returns:
+        List of SQL responses
+    """
+    if api_key in ["huggingface", "finetuned"]:
+        from openai import OpenAI
+
+        openai_api_key = "EMPTY"
+        openai_api_base = "http://localhost:8000/v1"
+
+        client = OpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+    else:
+        client = LlamaAPIClient()
+
+    # Generate all prompts first
+    prompts = []
+    for i, question in enumerate(question_list):
+        if knowledge_list:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
+            )
+        else:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question
+            )
+        prompts.append(cur_prompt)
+
+    print(f"Generated {len(prompts)} prompts for batch processing")
+
+    # Process prompts in parallel
+    if api_key in ["huggingface", "finetuned"]:
+        results = batch_local_llama(
+            client=client, prompts=prompts, model=model, max_workers=batch_size
+        )
+    else:
+        # For cloud API, we could implement a batch version of cloud_llama if needed
+        # For now, just process sequentially
+        results = []
+        for prompt in prompts:
+            plain_result = cloud_llama(
+                client=client,
+                api_key=api_key,
+                model=model,
+                prompt=prompt,
+                max_tokens=10240,
+                temperature=0,
+                stop=["--", "\n\n", ";", "#"],
+            )
+            results.append(plain_result)
+
+    # Format results
+    response_list = []
+    for i, result in enumerate(results):
+        if isinstance(result, str):
+            sql = result
+        else:
+            sql = "SELECT" + result["choices"][0]["text"]
+
+        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
+        sql = (
+            sql + "\t----- bird -----\t" + db_id
+        )  # to avoid unpredicted \t appearing in codex results
+        response_list.append(sql)
+
+    return response_list
+
+
 def question_package(data_json, knowledge=False):
     question_list = []
     for data in data_json:
@@ -302,9 +465,18 @@ if __name__ == "__main__":
     args_parser.add_argument("--api_key", type=str, required=True)
     args_parser.add_argument("--model", type=str, required=True)
     args_parser.add_argument("--data_output_path", type=str)
+    args_parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="Number of parallel requests for batch processing",
+    )
+    args_parser.add_argument(
+        "--use_batch", type=str, default="True", help="Whether to use batch processing"
+    )
     args = args_parser.parse_args()
 
-    if not args.api_key in ["huggingface", "finetuned"]:
+    if args.api_key not in ["huggingface", "finetuned"]:
         os.environ["LLAMA_API_KEY"] = args.api_key
 
         try:
@@ -332,22 +504,46 @@ if __name__ == "__main__":
     )
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
-    if args.use_knowledge == "True":
-        responses = collect_response_from_llama(  # collect_batch_response_from_llama
-            db_path_list=db_path_list,
-            question_list=question_list,
-            api_key=args.api_key,
-            model=args.model,
-            knowledge_list=knowledge_list,
-        )
+    use_batch = args.use_batch.lower() == "true"
+
+    if use_batch:
+        print(f"Using batch processing with batch_size={args.batch_size}")
+        if args.use_knowledge == "True":
+            responses = batch_collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=knowledge_list,
+                batch_size=args.batch_size,
+            )
+        else:
+            responses = batch_collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=None,
+                batch_size=args.batch_size,
+            )
     else:
-        responses = collect_response_from_llama(
-            db_path_list=db_path_list,
-            question_list=question_list,
-            api_key=args.api_key,
-            model=args.model,
-            knowledge_list=None,
-        )
+        print("Using sequential processing")
+        if args.use_knowledge == "True":
+            responses = collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=knowledge_list,
+            )
+        else:
+            responses = collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=None,
+            )
 
     output_name = args.data_output_path + "predict_" + args.mode + ".json"
 
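The new batch path in batch_local_llama fans prompts out over a thread pool and reassembles results by submission index, since as_completed yields futures in completion order rather than submission order. A self-contained sketch of that pattern, with a toy worker standing in for the chat.completions.create call (all names below are illustrative, not from the commit):

    import concurrent.futures
    import time

    from tqdm import tqdm


    def toy_worker(x):
        # Stand-in for an API call; sleeps so that completion order
        # differs from submission order.
        time.sleep(0.01 * (10 - x % 10))
        return x * x


    inputs = list(range(32))
    results = [None] * len(inputs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        # Map each future back to the index of the input it was built from.
        future_to_index = {executor.submit(toy_worker, x): i for i, x in enumerate(inputs)}
        for future in tqdm(
            concurrent.futures.as_completed(future_to_index), total=len(inputs)
        ):
            results[future_to_index[future]] = future.result()

    # Order is preserved despite out-of-order completion.
    assert results == [x * x for x in inputs]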

+ 64 - 19
end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py

@@ -5,6 +5,7 @@ import sqlite3
 import sys
 
 from func_timeout import func_timeout, FunctionTimedOut
+from tqdm import tqdm
 
 
 def load_json(dir):
@@ -17,7 +18,7 @@ def result_callback(result):
     exec_result.append(result)
 
 
-def execute_sql(predicted_sql, ground_truth, db_path):
+def execute_sql(predicted_sql, ground_truth, db_path, debug=False):
     conn = sqlite3.connect(db_path)
     # Connect to the database
     cursor = conn.cursor()
@@ -28,7 +29,7 @@ def execute_sql(predicted_sql, ground_truth, db_path):
     res = 0
     if set(predicted_res) == set(ground_truth_res):
         res = 1
-    else:
+    elif debug:
         print(
             f"\n\n==== INCORRECT SQL GENERATED ====\n{predicted_sql=}\n{predicted_res=}\n{ground_truth=}\n{ground_truth_res=}\n======\n\n"
         )
@@ -36,10 +37,14 @@ def execute_sql(predicted_sql, ground_truth, db_path):
     return res
 
 
-def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
+def execute_model(
+    predicted_sql, ground_truth, db_place, idx, meta_time_out, debug=False
+):
     try:
         res = func_timeout(
-            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
+            meta_time_out,
+            execute_sql,
+            args=(predicted_sql, ground_truth, db_place, debug),
         )
     except KeyboardInterrupt:
         sys.exit(0)
@@ -79,19 +84,35 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     return clean_sqls, db_path_list
 
 
-def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
+def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0, debug=False):
     pool = mp.Pool(processes=num_cpus)
-    for i, sql_pair in enumerate(sqls):
 
+    # Create the progress bar up front (None in debug mode) so the callback
+    # below never references an undefined name
+    pbar = tqdm(total=len(sqls), desc="Evaluating SQL queries") if not debug else None
+
+    for i, sql_pair in enumerate(sqls):
         predicted_sql, ground_truth = sql_pair
         pool.apply_async(
             execute_model,
-            args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
-            callback=result_callback,
+            args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out, debug),
+            callback=lambda result: result_callback_with_progress(
+                result, not debug, pbar
+            ),
         )
     pool.close()
     pool.join()
 
+    # Close the progress bar if not in debug mode
+    if not debug:
+        pbar.close()
+
+
+def result_callback_with_progress(result, use_progress, pbar=None):
+    exec_result.append(result)
+    if use_progress and pbar:
+        pbar.update(1)
+
 
 def sort_results(list_of_dicts):
     return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
@@ -137,14 +158,19 @@ def compute_acc_by_diff(exec_results, diff_json_path):
     )
 
 
-def print_data(score_lists, count_lists):
+def print_data(score_lists, count_lists, debug=False):
     levels = ["simple", "moderate", "challenging", "total"]
-    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
-    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
 
-    print(
-        "======================================    ACCURACY    ====================================="
-    )
+    if debug:
+        print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
+        print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
+        print(
+            "======================================    ACCURACY    ====================================="
+        )
+    else:
+        print("\nEvaluation Results:")
+        print("-" * 40)
+
     print(
         "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
     )
@@ -164,9 +190,19 @@ if __name__ == "__main__":
     args_parser.add_argument("--mode_predict", type=str, default="gpt")
     args_parser.add_argument("--difficulty", type=str, default="simple")
     args_parser.add_argument("--diff_json_path", type=str, default="")
+    args_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug mode with detailed prints"
+    )
     args = args_parser.parse_args()
     exec_result = []
 
+    if args.debug:
+        print("Debug mode enabled - showing detailed output")
+
+    # Show loading progress if not in debug mode
+    if not args.debug:
+        print("Loading SQL queries and database paths...")
+
     pred_queries, db_paths = package_sqls(
         args.predicted_sql_path,
         args.db_root_path,
@@ -179,20 +215,29 @@ if __name__ == "__main__":
     )
 
     query_pairs = list(zip(pred_queries, gt_queries))
+
+    if args.debug:
+        print(f"Executing {len(query_pairs)} SQL query pairs...")
+
     run_sqls_parallel(
         query_pairs,
         db_places=db_paths,
         num_cpus=args.num_cpus,
         meta_time_out=args.meta_time_out,
+        debug=args.debug,
     )
     exec_result = sort_results(exec_result)
 
-    print("Evaluating statistics...")
+    if args.debug:
+        print("Evaluating statistics...")
+
     simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
         exec_result, args.diff_json_path
     )
     score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
-    print_data(score_lists, count_lists)
-    print(
-        "==========================================================================================="
-    )
+    print_data(score_lists, count_lists, debug=args.debug)
+
+    if args.debug:
+        print(
+            "==========================================================================================="
+        )