
Enable vLLM-based eval for HF and fine-tuned models; clean up and refactor text2sql_eval; trim eval requirements to the minimum packages; merge the PEFT adapter into the base model so vLLM can serve the fine-tuned checkpoint

Jeff Tang, 3 months ago
Commit 5baa1e3fd7

+ 3 - 12
end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh

@@ -2,14 +2,6 @@ eval_path='../data/dev_20240627/dev.json'
 db_root_path='../data/dev_20240627/dev_databases/'
 ground_truth_path='../data/'
 
-# Llama models on Together
-#YOUR_API_KEY='YOUR_TOGETHER_API_KEY'
-#model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
-#model='meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo'
-#model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
-#model='meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
-#model='meta-llama/Llama-4-Scout-17B-16E-Instruct'
-
 # Llama models on Llama API
 YOUR_API_KEY='YOUR_LLAMA_API_KEY'
 model='Llama-3.3-8B-Instruct'
@@ -17,14 +9,13 @@ model='Llama-3.3-8B-Instruct'
 #model='Llama-4-Maverick-17B-128E-Instruct-FP8'
 #model='Llama-4-Scout-17B-16E-Instruct-FP8'
 
-# Llama model on Hugging Face Hub
-# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
+# Llama model on Hugging Face Hub https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
 # YOUR_API_KEY='huggingface'
 # model='meta-llama/Llama-3.1-8B-Instruct'
 
 # Fine-tuned Llama models locally
-#YOUR_API_KEY='finetuned'
-#model='../fine_tuning/llama31-8b-text2sql-fine-tuned'
+# YOUR_API_KEY='finetuned'
+# model='../fine-tuning/final_test/llama31-8b-text2sql-peft-quantized-cot_merged'
 
 data_output_path="./output/$model/"
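
Note on the new local model path: vLLM serves plain Hugging Face checkpoints, so a PEFT/LoRA adapter has to be merged into its base model first, which is what the `_merged` suffix above refers to. Below is a minimal sketch of such a merge step using the standard peft API; the directory names are illustrative placeholders, not the repo's actual merge script.

```python
# Minimal sketch: fold a PEFT (LoRA) adapter into its base model so vLLM can
# serve the merged checkpoint directly. Paths are illustrative placeholders.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "llama31-8b-text2sql-adapter"   # hypothetical PEFT adapter checkpoint
merged_dir = "llama31-8b-text2sql-merged"     # output directory vLLM will load

model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_dir, torch_dtype=torch.float16, device_map="auto"
)
merged = model.merge_and_unload()             # merge LoRA weights into the base model
merged.save_pretrained(merged_dir)
# assumes the tokenizer was saved alongside the adapter during fine-tuning
AutoTokenizer.from_pretrained(adapter_dir).save_pretrained(merged_dir)
```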
 

+ 54 - 143
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -5,21 +5,13 @@ import re
 import sqlite3
 from typing import Dict
 
-import torch
-from langchain_together import ChatTogether
 from llama_api_client import LlamaAPIClient
-from peft import AutoPeftModelForCausalLM
-from tqdm import tqdm
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    pipeline,
-)
-
-MAX_NEW_TOKENS=10240  # If API has max tokens (vs max new tokens), we calculate it
-
-def local_llama(prompt, pipe):
+
+MAX_NEW_TOKENS = 10240  # If API has max tokens (vs max new tokens), we calculate it
+TIMEOUT = 60  # Timeout in seconds for each API call
+
+
+def local_llama(client, prompt, model):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
     # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
@@ -28,27 +20,13 @@ def local_llama(prompt, pipe):
         {"content": SYSTEM_PROMPT, "role": "system"},
         {"role": "user", "content": prompt},
     ]
-
-    raw_prompt = pipe.tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-    )
-
-    print(f"local_llama: {raw_prompt=}")
-
-    outputs = pipe(
-        raw_prompt,
-        max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=False,
-        temperature=0.0,
-        top_k=50,
-        top_p=0.1,
-        eos_token_id=pipe.tokenizer.eos_token_id,
-        pad_token_id=pipe.tokenizer.pad_token_id,
+    print(f"local_llama: {model=}")
+    chat_response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        timeout=TIMEOUT,
     )
-
-    answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
+    answer = chat_response.choices[0].message.content.strip()
 
     pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
     matches = pattern.findall(answer)
@@ -98,12 +76,9 @@ def nice_look_table(column_names: list, values: list):
         for i in range(len(column_names))
     ]
 
-    # Print the column names
     header = "".join(
         f"{column.rjust(width)} " for column, width in zip(column_names, widths)
     )
-    # print(header)
-    # Print the values
     for value in values:
         row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths))
         rows.append(row)
@@ -176,33 +151,22 @@ def generate_combined_prompts_one(db_path, question, knowledge=None):
     return combined_prompts
 
 
-def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
+def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
 
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     try:
-        if model.startswith("meta-llama/"):
-            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
-            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS
-            llm = ChatTogether(
-                model=model,
-                temperature=0,
-                max_tokens=final_max_tokens,
-            )
-            answer = llm.invoke(final_prompt).content
-        else:
-            client = LlamaAPIClient()
-            messages = [
-                {"content": SYSTEM_PROMPT, "role": "system"},
-                {"role": "user", "content": prompt},
-            ]
-            final_max_tokens = len(messages) + MAX_NEW_TOKENS
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-                max_completion_tokens=final_max_tokens,
-            )
-            answer = response.completion_message.content.text
+        messages = [
+            {"content": SYSTEM_PROMPT, "role": "system"},
+            {"role": "user", "content": prompt},
+        ]
+        final_max_tokens = len(messages) + MAX_NEW_TOKENS
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=0,
+            max_completion_tokens=final_max_tokens,
+        )
+        answer = response.completion_message.content.text
 
         pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
         matches = pattern.findall(answer)
@@ -218,57 +182,6 @@ def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
     return result
 
 
-def huggingface_finetuned(api_key, model):
-    if api_key == "finetuned":
-        model_id = model
-
-        # Check if this is a PEFT model by looking for adapter_config.json
-        import os
-
-        is_peft_model = os.path.exists(os.path.join(model_id, "adapter_config.json"))
-
-        if is_peft_model:
-            # Use AutoPeftModelForCausalLM for PEFT fine-tuned models
-            print(f"Loading PEFT model from {model_id}")
-            model = AutoPeftModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", torch_dtype=torch.float16
-            )
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-        else:
-            # Use AutoModelForCausalLM for FFT (Full Fine-Tuning) models
-            print(f"Loading FFT model from {model_id}")
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", torch_dtype=torch.float16
-            )
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-            # For FFT models, handle pad token if it was added during training
-            if tokenizer.pad_token is None:
-                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-                model.resize_token_embeddings(len(tokenizer))
-
-        tokenizer.padding_side = "right"  # to prevent warnings
-
-    elif api_key == "huggingface":
-        model_id = model
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            quantization_config=bnb_config,  # None
-        )
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-    return pipe
-
-
 def collect_response_from_llama(
     db_path_list, question_list, api_key, model, knowledge_list=None
 ):
@@ -280,9 +193,19 @@ def collect_response_from_llama(
     response_list = []
 
     if api_key in ["huggingface", "finetuned"]:
-        pipe = huggingface_finetuned(api_key=api_key, model=model)
+        from openai import OpenAI
+
+        openai_api_key = "EMPTY"
+        openai_api_base = "http://localhost:8000/v1"
 
-    for i, question in tqdm(enumerate(question_list)):
+        client = OpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+    else:
+        client = LlamaAPIClient()
+
+    for i, question in enumerate(question_list):
         print(
             "--------------------- processing question #{}---------------------".format(
                 i + 1
@@ -300,9 +223,11 @@ def collect_response_from_llama(
             )
 
         if api_key in ["huggingface", "finetuned"]:
-            plain_result = local_llama(prompt=cur_prompt, pipe=pipe)
+            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
         else:
+
             plain_result = cloud_llama(
+                client=client,
                 api_key=api_key,
                 model=model,
                 prompt=cur_prompt,
@@ -310,7 +235,7 @@ def collect_response_from_llama(
                 temperature=0,
                 stop=["--", "\n\n", ";", "#"],
             )
-        if isinstance(plain_result, str):
+        if type(plain_result) == str:
             sql = plain_result
         else:
             sql = "SELECT" + plain_result["choices"][0]["text"]
@@ -379,37 +304,23 @@ if __name__ == "__main__":
     args_parser.add_argument("--data_output_path", type=str)
     args = args_parser.parse_args()
 
-    if args.api_key not in ["huggingface", "finetuned"]:
-        if args.model.startswith("meta-llama/"):  # Llama model on together
+    if not args.api_key in ["huggingface", "finetuned"]:
+        os.environ["LLAMA_API_KEY"] = args.api_key
+
+        try:
+            client = LlamaAPIClient()
 
-            os.environ["TOGETHER_API_KEY"] = args.api_key
-            llm = ChatTogether(
+            response = client.chat.completions.create(
                 model=args.model,
+                messages=[{"role": "user", "content": "125*125 is?"}],
                 temperature=0,
             )
-            try:
-                response = llm.invoke("125*125 is?").content
-                print(f"{response=}")
-            except Exception as exception:
-                print(f"{exception=}")
-                exit(1)
-        else:  # Llama model on Llama API
-            os.environ["LLAMA_API_KEY"] = args.api_key
-
-            try:
-                client = LlamaAPIClient()
-
-                response = client.chat.completions.create(
-                    model=args.model,
-                    messages=[{"role": "user", "content": "125*125 is?"}],
-                    temperature=0,
-                )
-                answer = response.completion_message.content.text
+            answer = response.completion_message.content.text
 
-                print(f"{answer=}")
-            except Exception as exception:
-                print(f"{exception=}")
-                exit(1)
+            print(f"{answer=}")
+        except Exception as exception:
+            print(f"{exception=}")
+            exit(1)
 
     eval_data = json.load(open(args.eval_path, "r"))
     # '''for debug'''
@@ -422,7 +333,7 @@ if __name__ == "__main__":
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
     if args.use_knowledge == "True":
-        responses = collect_response_from_llama(
+        responses = collect_response_from_llama(  # collect_batch_response_from_llama
             db_path_list=db_path_list,
             question_list=question_list,
             api_key=args.api_key,
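
The `huggingface`/`finetuned` branch above no longer loads the model in-process; it assumes an OpenAI-compatible server (e.g. vLLM) is already listening on `http://localhost:8000/v1`. Below is a minimal sanity check of that assumption; the launch command in the comment and the prompt are illustrative, and the served model id is whatever path or `--served-model-name` the server was started with.

```python
# Minimal sanity check for the local OpenAI-compatible endpoint assumed above.
# Start the server separately first, e.g.:  vllm serve <path-to-merged-model> --port 8000
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

served = [m.id for m in client.models.list().data]   # vLLM reports the served model id(s)
print(f"{served=}")

resp = client.chat.completions.create(
    model=served[0],
    messages=[{"role": "user", "content": "Translate to SQL: how many rows are in table t?"}],
    timeout=60,
)
print(resp.choices[0].message.content)
```

Since vLLM registers the checkpoint path as the model id by default (unless overridden with `--served-model-name`), the model name passed to llama_text2sql.py should match what the server reports here.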

+ 5 - 18
end-to-end-use-cases/coding/text2sql/eval/requirements.txt

@@ -1,19 +1,6 @@
-llama_api_client==0.1.1
-langchain-together==0.3.0
-sqlparse==0.5.3
-torch==2.4.1
-tensorboard==2.19.0
-liger-kernel==0.4.2
-setuptools==78.1.1
-deepspeed==0.15.4
-transformers==4.46.3
-datasets==3.6.0
-accelerate==1.1.1
-bitsandbytes==0.44.1
-trl==0.12.1
-peft==0.13.2
-lighteval==0.6.2
-hf-transfer==0.1.8
+llama_api_client==0.1.2
 func_timeout==4.3.5
-vllm==0.9.2
-flashinfer-python==0.2.7.post1
+
+# uncomment to run vllm for eval with Llama 3.1 8B on HF and its fine-tuned models
+# vllm==0.9.2
+# openai==1.90.0

+ 5 - 4
end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py

@@ -4,7 +4,7 @@ import multiprocessing as mp
 import sqlite3
 import sys
 
-from func_timeout import FunctionTimedOut, func_timeout
+from func_timeout import func_timeout, FunctionTimedOut
 
 
 def load_json(dir):
@@ -44,10 +44,10 @@ def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
-        result = [("timeout",)]
+        result = [(f"timeout",)]
         res = 0
     except Exception as e:
-        result = [("error",)]  # possibly len(query) > 512 or not executable
+        result = [(f"{e}",)]  # possibly len(query) > 512 or not executable
         res = 0
     result = {"sql_idx": idx, "res": res}
     return result
@@ -59,7 +59,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     if mode == "gpt":
         sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r"))
         for idx, sql_str in sql_data.items():
-            if isinstance(sql_str, str):
+            if type(sql_str) == str:
                 sql, db_name = sql_str.split("\t----- bird -----\t")
             else:
                 sql, db_name = " ", "financial"
@@ -82,6 +82,7 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
 def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
     pool = mp.Pool(processes=num_cpus)
     for i, sql_pair in enumerate(sqls):
+
         predicted_sql, ground_truth = sql_pair
         pool.apply_async(
             execute_model,
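
For context on the exception handling changed above: each predicted/ground-truth SQL pair is executed under func_timeout, and both a timeout and any SQLite error are scored as res = 0. A minimal sketch of that guard pattern, with an in-memory database and a toy query standing in for the BIRD dev databases:

```python
# Minimal sketch of the timeout-guarded SQL execution used by execute_model above.
# The in-memory database and query are toy stand-ins for the BIRD dev databases.
import sqlite3
from func_timeout import func_timeout, FunctionTimedOut

def run_sql(db_path, sql):
    conn = sqlite3.connect(db_path)
    try:
        return conn.cursor().execute(sql).fetchall()
    finally:
        conn.close()

try:
    rows = func_timeout(30.0, run_sql, args=(":memory:", "SELECT 1"))
    print(rows)        # [(1,)]
except FunctionTimedOut:
    res = 0            # query ran past meta_time_out
except Exception as e:
    res = 0            # not executable; the updated code records str(e) in the result
```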