@@ -1,4 +1,5 @@
 import argparse
+import concurrent.futures
 import json
 import os
 import re
@@ -6,6 +7,7 @@ import sqlite3
 from typing import Dict
 
 from llama_api_client import LlamaAPIClient
+from tqdm import tqdm
 
 MAX_NEW_TOKENS = 10240  # If API has max tokens (vs max new tokens), we calculate it
 TIMEOUT = 60  # Timeout in seconds for each API call
@@ -25,20 +27,100 @@ def local_llama(client, prompt, model):
         model=model,
         messages=messages,
         timeout=TIMEOUT,
+        temperature=0,
     )
     answer = chat_response.choices[0].message.content.strip()
 
     pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
     matches = pattern.findall(answer)
-    if matches != []:
-        result = matches[0]
-    else:
+    if not matches:
         result = answer
+    else:
+        result = matches[0]
 
     print(f"{result=}")
     return result
 
 
+def batch_local_llama(client, prompts, model, max_workers=8):
+    """
+    Process multiple prompts in parallel using the local vLLM server.
+
+    Args:
+        client: OpenAI client
+        prompts: List of prompts to process
+        model: Model name
+        max_workers: Maximum number of parallel workers
+
+    Returns:
+        List of results in the same order as prompts
+    """
+    SYSTEM_PROMPT = (
+        "You are a text to SQL query translator. Using the SQLite DB Schema "
+        "and the External Knowledge, translate the following text question "
+        "into a SQLite SQL select statement."
+    )
+
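+    # Worker body: one chat completion per prompt; the SQL is pulled from a
+    # ```sql fenced block when present, otherwise the raw answer is returned.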
+    def process_single_prompt(prompt):
+        messages = [
+            {"content": SYSTEM_PROMPT, "role": "system"},
+            {"role": "user", "content": prompt},
+        ]
+        try:
+            chat_response = client.chat.completions.create(
+                model=model,
+                messages=messages,
+                timeout=TIMEOUT,
+                temperature=0,
+            )
+            answer = chat_response.choices[0].message.content.strip()
+
+            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+            matches = pattern.findall(answer)
+            if not matches:
+                result = answer
+            else:
+                result = matches[0]
+
+            return result
+        except Exception as e:
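+            # Surface the failure as a sentinel string so the batch keeps going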
+            print(f"Error processing prompt: {e}")
+            return f"error:{e}"
+
+    print(
+        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
+        f"using {max_workers} workers"
+    )
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks and create a map of futures to their indices
+        future_to_index = {
+            executor.submit(process_single_prompt, prompt): i
+            for i, prompt in enumerate(prompts)
+        }
+
+        # Initialize results list with None values
+        results = [None] * len(prompts)
+
+        # Process completed futures as they complete
+        for future in tqdm(
+            concurrent.futures.as_completed(future_to_index),
+            total=len(prompts),
+            desc="Processing prompts",
+        ):
+            index = future_to_index[future]
+            try:
+                results[index] = future.result()
+            except Exception as e:
+                print(f"Error processing prompt at index {index}: {e}")
+                results[index] = f"error:{e}"
+
+    return results
+
+
 def new_directory(path):
     if not os.path.exists(path):
         os.makedirs(path)
@@ -235,7 +317,7 @@ def collect_response_from_llama(
         temperature=0,
         stop=["--", "\n\n", ";", "#"],
     )
-    if type(plain_result) == str:
+    if isinstance(plain_result, str):
         sql = plain_result
     else:
         sql = "SELECT" + plain_result["choices"][0]["text"]
@@ -250,6 +332,90 @@ def collect_response_from_llama(
     return response_list
 
 
+def batch_collect_response_from_llama(
+    db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8
+):
+    """
+    Process multiple questions in parallel using the local vLLM server.
+
+    Args:
+        db_path_list: List of database paths
+        question_list: List of questions
+        api_key: API key
+        model: Model name
+        knowledge_list: List of knowledge strings (optional)
+        batch_size: Number of parallel requests
+
+    Returns:
+        List of SQL responses
+    """
+    if api_key in ["huggingface", "finetuned"]:
+        from openai import OpenAI
+
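+        # vLLM exposes an OpenAI-compatible endpoint; the key is a placeholder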
+        openai_api_key = "EMPTY"
+        openai_api_base = "http://localhost:8000/v1"
+
+        client = OpenAI(
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+    else:
+        client = LlamaAPIClient()
+
+    # Generate all prompts first
+    prompts = []
+    for i, question in enumerate(question_list):
+        if knowledge_list:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
+            )
+        else:
+            cur_prompt = generate_combined_prompts_one(
+                db_path=db_path_list[i], question=question
+            )
+        prompts.append(cur_prompt)
+
+    print(f"Generated {len(prompts)} prompts for batch processing")
+
+    # Process prompts in parallel
+    if api_key in ["huggingface", "finetuned"]:
+        results = batch_local_llama(
+            client=client, prompts=prompts, model=model, max_workers=batch_size
+        )
+    else:
+        # A batch version of cloud_llama could be added later; for now the
+        # cloud API path processes prompts sequentially
+        results = []
+        for prompt in prompts:
+            plain_result = cloud_llama(
+                client=client,
+                api_key=api_key,
+                model=model,
+                prompt=prompt,
+                max_tokens=10240,
+                temperature=0,
+                stop=["--", "\n\n", ";", "#"],
+            )
+            results.append(plain_result)
+
+    # Format results
+    response_list = []
+    for i, result in enumerate(results):
+        if isinstance(result, str):
+            sql = result
+        else:
+            sql = "SELECT" + result["choices"][0]["text"]
+
+        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
+        sql = (
+            sql + "\t----- bird -----\t" + db_id
+        )  # to avoid unpredicted \t appearing in codex results
+        response_list.append(sql)
+
+    return response_list
+
+
 def question_package(data_json, knowledge=False):
     question_list = []
     for data in data_json:
@@ -302,9 +468,18 @@ if __name__ == "__main__":
     args_parser.add_argument("--api_key", type=str, required=True)
     args_parser.add_argument("--model", type=str, required=True)
     args_parser.add_argument("--data_output_path", type=str)
+    args_parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="Number of parallel requests for batch processing",
+    )
+    args_parser.add_argument(
+        "--use_batch", type=str, default="True", help="Whether to use batch processing"
+    )
     args = args_parser.parse_args()
 
-    if not args.api_key in ["huggingface", "finetuned"]:
+    if args.api_key not in ["huggingface", "finetuned"]:
         os.environ["LLAMA_API_KEY"] = args.api_key
 
     try:
@@ -332,22 +507,47 @@ if __name__ == "__main__":
         )
         assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
-    if args.use_knowledge == "True":
-        responses = collect_response_from_llama(  # collect_batch_response_from_llama
-            db_path_list=db_path_list,
-            question_list=question_list,
-            api_key=args.api_key,
-            model=args.model,
-            knowledge_list=knowledge_list,
-        )
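+    # argparse delivers --use_batch as a string, so normalize before comparing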
+    use_batch = args.use_batch.lower() == "true"
+
+    if use_batch:
+        print(f"Using batch processing with batch_size={args.batch_size}")
+        if args.use_knowledge == "True":
+            responses = batch_collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=knowledge_list,
+                batch_size=args.batch_size,
+            )
+        else:
+            responses = batch_collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=None,
+                batch_size=args.batch_size,
+            )
     else:
-        responses = collect_response_from_llama(
-            db_path_list=db_path_list,
-            question_list=question_list,
-            api_key=args.api_key,
-            model=args.model,
-            knowledge_list=None,
-        )
+        print("Using sequential processing")
+        if args.use_knowledge == "True":
+            responses = collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=knowledge_list,
+            )
+        else:
+            responses = collect_response_from_llama(
+                db_path_list=db_path_list,
+                question_list=question_list,
+                api_key=args.api_key,
+                model=args.model,
+                knowledge_list=None,
+            )
 
     output_name = args.data_output_path + "predict_" + args.mode + ".json"