@@ -13,36 +13,7 @@ MAX_NEW_TOKENS = 10240 # If API has max tokens (vs max new tokens), we calculat
 TIMEOUT = 60  # Timeout in seconds for each API call
 
 
-def local_llama(client, prompt, model):
-    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
-    # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
-    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
-
-    messages = [
-        {"content": SYSTEM_PROMPT, "role": "system"},
-        {"role": "user", "content": prompt},
-    ]
-    print(f"local_llama: {model=}")
-    chat_response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        timeout=TIMEOUT,
-        temperature=0,
-    )
-    answer = chat_response.choices[0].message.content.strip()
-
-    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
-    matches = pattern.findall(answer)
-    if not matches:
-        result = answer
-    else:
-        result = matches[0]
-
-    print(f"{result=}")
-    return result
-
-
-def batch_local_llama(client, prompts, model, max_workers=8):
+def local_llama(client, api_key, prompts, model, max_workers=8):
     """
     Process multiple prompts in parallel using the vllm server.
@@ -55,10 +26,19 @@ def batch_local_llama(client, prompts, model, max_workers=8):
     Returns:
         List of results in the same order as prompts
     """
+
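+    # api_key "huggingface" keeps the plain translation prompt; any other local key
+    # (e.g. "finetuned") uses the step-by-step reasoning prompt for the reasoning-dataset fine-tune.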
     SYSTEM_PROMPT = (
-        "You are a text to SQL query translator. Using the SQLite DB Schema "
-        "and the External Knowledge, translate the following text question "
-        "into a SQLite SQL select statement."
+        (
+            "You are a text to SQL query translator. Using the SQLite DB Schema "
+            "and the External Knowledge, translate the following text question "
+            "into a SQLite SQL select statement."
+        )
+        if api_key == "huggingface"
+        else (
+            "You are a text to SQL query translator. Using the SQLite DB Schema "
+            "and the External Knowledge, generate the step-by-step reasoning and "
+            "then the final SQLite SQL select statement from the text question."
+        )
     )
 
     def process_single_prompt(prompt):
@@ -88,7 +68,7 @@ def batch_local_llama(client, prompts, model, max_workers=8):
             return f"error:{e}"
 
     print(
-        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
+        f"local_llama: Processing {len(prompts)} prompts with {model=} "
         f"using {max_workers} workers"
     )
     results = []
@@ -231,103 +211,61 @@ def generate_combined_prompts_one(db_path, question, knowledge=None):
     return combined_prompts
 
 
-def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
-
-    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
-    try:
-        messages = [
-            {"content": SYSTEM_PROMPT, "role": "system"},
-            {"role": "user", "content": prompt},
-        ]
-        final_max_tokens = len(messages) + MAX_NEW_TOKENS
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=0,
-            max_completion_tokens=final_max_tokens,
-        )
-        answer = response.completion_message.content.text
-
-        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
-        matches = pattern.findall(answer)
-        if matches != []:
-            result = matches[0]
-        else:
-            result = answer
-
-        print(result)
-    except Exception as e:
-        result = "error:{}".format(e)
-    print(f"{result=}")
-    return result
-
-
-def collect_response_from_llama(
-    db_path_list, question_list, api_key, model, knowledge_list=None
-):
+def cloud_llama(client, api_key, model, prompts):
     """
-    :param db_path: str
-    :param question_list: []
-    :return: dict of responses
-    """
-    response_list = []
-
-    if api_key in ["huggingface", "finetuned"]:
-        from openai import OpenAI
+    Process multiple prompts sequentially using the cloud API, showing progress with tqdm.
 
-        openai_api_key = "EMPTY"
-        openai_api_base = "http://localhost:8000/v1"
+    Args:
+        client: LlamaAPIClient
+        api_key: API key
+        model: Model name
+        prompts: List of prompts to process (or a single prompt as string)
 
-        client = OpenAI(
-            api_key=openai_api_key,
-            base_url=openai_api_base,
-        )
-    else:
-        client = LlamaAPIClient()
+    Returns:
+        List of results if prompts is a list, or a single result if prompts is a string
+    """
+    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
 
-    for i, question in enumerate(question_list):
-        print(
-            "--------------------- processing question #{}---------------------".format(
-                i + 1
-            )
-        )
-        print("the question is: {}".format(question))
+    # Handle the case where a single prompt is passed
+    single_prompt = False
+    if isinstance(prompts, str):
+        prompts = [prompts]
+        single_prompt = True
 
-        if knowledge_list:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
-            )
-        else:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question
-            )
-
-        if api_key in ["huggingface", "finetuned"]:
-            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
-        else:
+    results = []
 
-            plain_result = cloud_llama(
-                client=client,
-                api_key=api_key,
+    # Process each prompt sequentially with tqdm progress bar
+    for prompt in tqdm(prompts, desc="Processing prompts", unit="prompt"):
+        try:
+            messages = [
+                {"content": SYSTEM_PROMPT, "role": "system"},
+                {"role": "user", "content": prompt},
+            ]
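+            # Per the MAX_NEW_TOKENS note at the top of the file, the cap passed to the
+            # API is the message count plus MAX_NEW_TOKENS.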
+            final_max_tokens = len(messages) + MAX_NEW_TOKENS
+            response = client.chat.completions.create(
                 model=model,
-                prompt=cur_prompt,
-                max_tokens=10240,
+                messages=messages,
                 temperature=0,
-                stop=["--", "\n\n", ";", "#"],
+                max_completion_tokens=final_max_tokens,
             )
-        if isinstance(plain_result, str):
-            sql = plain_result
-        else:
-            sql = "SELECT" + plain_result["choices"][0]["text"]
+            answer = response.completion_message.content.text
 
-        # responses_dict[i] = sql
-        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
-        sql = (
-            sql + "\t----- bird -----\t" + db_id
-        )  # to avoid unpredicted \t appearing in codex results
-        response_list.append(sql)
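+            # Keep only the SQL inside a ```sql fenced block when present; otherwise
+            # fall back to the raw model answer.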
+            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+            matches = pattern.findall(answer)
+            if matches != []:
+                result = matches[0]
+            else:
+                result = answer
+        except Exception as e:
+            result = "error:{}".format(e)
+        print(f"{result=}")
 
-    return response_list
+        results.append(result)
+
+    # Return a single result if input was a single prompt
+    if single_prompt:
+        return results[0]
+    return results
 
 
 def batch_collect_response_from_llama(
@@ -376,25 +314,24 @@ def batch_collect_response_from_llama(
     print(f"Generated {len(prompts)} prompts for batch processing")
 
     # Process prompts in parallel
-    if api_key in ["huggingface", "finetuned"]:
-        results = batch_local_llama(
-            client=client, prompts=prompts, model=model, max_workers=batch_size
+    if api_key in [
+        "huggingface",
+        "finetuned",
+    ]:  # running vllm on multiple GPUs to see best performance
+        results = local_llama(
+            client=client,
+            api_key=api_key,
+            prompts=prompts,
+            model=model,
+            max_workers=batch_size,
         )
     else:
-        # For cloud API, we could implement a batch version of cloud_llama if needed
-        # For now, just process sequentially
-        results = []
-        for prompt in prompts:
-            plain_result = cloud_llama(
-                client=client,
-                api_key=api_key,
-                model=model,
-                prompt=prompt,
-                max_tokens=10240,
-                temperature=0,
-                stop=["--", "\n\n", ";", "#"],
-            )
-            results.append(plain_result)
+        results = cloud_llama(
+            client=client,
+            api_key=api_key,
+            model=model,
+            prompts=prompts,
+        )
 
     # Format results
     response_list = []
@@ -471,9 +408,6 @@ if __name__ == "__main__":
         default=8,
         help="Number of parallel requests for batch processing",
     )
-    args_parser.add_argument(
-        "--use_batch", type=str, default="True", help="Whether to use batch processing"
-    )
     args = args_parser.parse_args()
 
     if args.api_key not in ["huggingface", "finetuned"]:
@@ -488,62 +422,36 @@ if __name__ == "__main__":
             temperature=0,
         )
         answer = response.completion_message.content.text
-
-        print(f"{answer=}")
     except Exception as exception:
         print(f"{exception=}")
         exit(1)
 
     eval_data = json.load(open(args.eval_path, "r"))
-    # '''for debug'''
-    # eval_data = eval_data[:3]
-    # '''for debug'''
 
     question_list, db_path_list, knowledge_list = decouple_question_schema(
         datasets=eval_data, db_root_path=args.db_root_path
     )
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
-    use_batch = args.use_batch.lower() == "true"
-
-    if use_batch:
-        print(f"Using batch processing with batch_size={args.batch_size}")
-        if args.use_knowledge == "True":
-            responses = batch_collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=knowledge_list,
-                batch_size=args.batch_size,
-            )
-        else:
-            responses = batch_collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=None,
-                batch_size=args.batch_size,
-            )
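+    # Batch processing is now the only mode; the knowledge list is passed through
+    # only when --use_knowledge is "True".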
+    print(f"Using batch processing with batch_size={args.batch_size}")
+    if args.use_knowledge == "True":
+        responses = batch_collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=knowledge_list,
+            batch_size=args.batch_size,
+        )
     else:
-        print("Using sequential processing")
-        if args.use_knowledge == "True":
-            responses = collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=knowledge_list,
-            )
-        else:
-            responses = collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=None,
-            )
+        responses = batch_collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=None,
+            batch_size=args.batch_size,
+        )
 
     output_name = args.data_output_path + "predict_" + args.mode + ".json"