@@ -12,6 +12,7 @@ from typing import List
 from datasets import Dataset
 from func_timeout import func_timeout, FunctionTimedOut
+from together import Together
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.trainer_utils import get_last_checkpoint
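The new `Together()` client authenticates via the `TOGETHER_API_KEY` environment variable. As a quick sanity check of the call pattern this PR relies on (prompt and key are placeholders):

```python
import os
from together import Together

os.environ.setdefault("TOGETHER_API_KEY", "<your-key>")  # placeholder

client = Together()
response = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
    messages=[{"role": "user", "content": "Reply with only the number 1.0"}],
    temperature=0,
)
print(response.choices[0].message.content)  # expect something like "1.0"
```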
@@ -267,8 +268,16 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
     """
 
     rewards = []
-
-    for completion, gt in zip(completions, answer):
+    questions = kwargs.get("question")
+    evidences = kwargs.get("evidence")
+
+    for completion, gt, question, evidence in zip(
+        completions, answer, questions, evidences
+    ):
+        # print(f">>>>>ensemble_n_gram_reward_func: {gt=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {completion=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {question=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {evidence=}")
         try:
             match = re.search(r"<answer>(.*?)<\/answer>", completion)
             if match is None:
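The `jaccard_1`/`jaccard_2`/`jaccard_3` scores averaged below are not shown in this hunk; presumably they are Jaccard similarities over word n-grams of orders 1–3. A minimal sketch under that assumption (`ngram_jaccard` is a hypothetical helper, not from this file):

```python
def ngram_jaccard(pred: str, gold: str, n: int) -> float:
    """Jaccard similarity between the word n-gram sets of two strings."""
    def ngrams(text: str) -> set:
        toks = text.split()
        return {tuple(toks[i : i + n]) for i in range(len(toks) - n + 1)}
    a, b = ngrams(pred), ngrams(gold)
    if not a and not b:
        return 1.0  # two empty strings are trivially identical
    return len(a & b) / len(a | b)

pred = "SELECT name FROM users"
gold = "SELECT name FROM customers"
jaccard_1, jaccard_2, jaccard_3 = (ngram_jaccard(pred, gold, n) for n in (1, 2, 3))
```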
@@ -285,6 +294,7 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
             # Average the scores to get the final ensemble reward
             average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0
+            print(f"{average_jaccard=}")
             rewards.append(average_jaccard)
         except Exception as e:
             rewards.append(0.0)
@@ -293,6 +303,118 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
     return rewards
 
 
+def llm_as_a_judge_reward_func(completions, answer, **kwargs):
+    """
+    Use Llama 3.3 70B as a judge to evaluate the quality of the generated SQL statements by comparing them to the ground truth answers.
+
+    Args:
+        completions (list[str]): Generated outputs
+        answer (list[str]): Ground truth answers (from content with "role": "assistant" in train_text2sql_sft_dataset.json)
+
+    Returns:
+        list[float]: Reward scores
+    """
+
+    rewards = []
+
+    client = Together()
+    PROMPT_TEMPLATE = """
+You are an experienced database expert. Your task is to evaluate a generated SQL query by comparing it
+to the ground truth (gold) query and then assign a score between 0.0 and 2.0. A higher score indicates
+the predicted query is more correct, while a score of 0.0 means it is completely incorrect.
+
+Follow these evaluation rules strictly:
+
+1. SELECT Clause:
+• Only select columns that are mentioned in the user's question.
+• Do not include unnecessary columns or values.
+
+2. Aggregation (MAX/MIN):
+• Always perform JOINs before applying MAX() or MIN().
+
+3. ORDER BY with Distinct Values:
+• Use a GROUP BY <column> before an ORDER BY <column> ASC|DESC to ensure distinct values.
+
+4. Handling NULLs:
+• If a column may contain NULL values (indicated by "None" in value examples or explicitly mentioned), include a JOIN or a WHERE <column> IS NOT NULL clause.
+
+5. FROM/JOIN Clauses:
+• Only include the tables essential for answering the question.
+
+6. Strictly Follow Hints:
+• Adhere to all hints provided with the question.
+
+7. Thorough Question Analysis:
+• Ensure all conditions and requirements mentioned in the question are addressed.
+
+8. DISTINCT Keyword:
+• Use SELECT DISTINCT when the question requires unique values (e.g., IDs, URLs) or when column statistics (Value Statics) indicate its necessity.
+
+9. Column Selection:
+• Carefully analyze column descriptions and hints to choose the correct column when similar columns exist across tables.
+
+10. String Concatenation:
+• Do not use any string concatenation methods (e.g., || ' ' ||) in the SELECT clause.
+
+11. JOIN Preference:
+• Prefer using INNER JOIN over nested SELECT statements.
+
+12. Date Processing:
+• Use STRFTIME() for any date manipulations (e.g., STRFTIME('%Y', SOMETIME) to extract the year).
+
+You are provided with the following inputs:
+• Question: {QUESTION}
+• Hint: {HINT}
+• Gold Query: {GOLD_QUERY}
+• Predicted Query: {PREDICTED_QUERY}
+
+Based on the above, return a single numeric score between 0.0 and 2.0 that reflects how
+correct the predicted query is compared to the gold query. Respond with only the score and
+no additional explanation.
+"""
+
+    questions = kwargs.get("question")
+    evidences = kwargs.get("evidence")
+    for completion, gt, question, evidence in zip(
+        completions, answer, questions, evidences
+    ):
+        try:
+            match = re.search(r"<answer>(.*?)<\/answer>", completion)
+            if match is None:
+                rewards.append(0.0)
+                log_reward(
+                    "llm_as_a_judge_reward_func 0 - no answer tag found", completion, gt
+                )
+                continue
+            # Extract the "answer" part from the completion
+            predicted_sql = match.group(1).strip()
+            prompt = PROMPT_TEMPLATE.format(
+                QUESTION=question,
+                HINT=evidence,
+                GOLD_QUERY=gt,
+                PREDICTED_QUERY=predicted_sql,
+            )
+            response = client.chat.completions.create(
+                model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0,
+            )
+            rewards.append(float(response.choices[0].message.content))
+        except Exception as e:
+            rewards.append(0.0)
+            log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
+
+    return rewards
+
+
 def get_checkpoint(training_args: GRPOConfig):
     last_checkpoint = None
     if os.path.isdir(training_args.output_dir):
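For context, reward functions with this `(completions, answer, **kwargs)` signature are the shape TRL's `GRPOTrainer` expects, so the new judge can be dropped into the existing reward list. A minimal wiring sketch (model name and output dir are placeholders, not from this PR):

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(output_dir="outputs/text2sql-grpo")  # placeholder
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-Coder-1.5B-Instruct",  # placeholder policy model
    reward_funcs=[ensemble_n_gram_reward_func, llm_as_a_judge_reward_func],
    args=training_args,
    train_dataset=dataset,  # built below with "question"/"evidence" columns
)
trainer.train()
```

Note that the judge returns scores in [0.0, 2.0] while the averaged Jaccard reward stays in [0.0, 1.0], so with default weights the judge carries up to twice the influence when the rewards are summed.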
@@ -332,7 +454,10 @@ def generate_schema_prompt(db_path, num_rows=None):
             cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
             column_names = [description[0] for description in cursor.description]
             values = cursor.fetchall()
-            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            # Format the rows as a simple table representation
+            rows_prompt = "\n".join(
+                "\t".join(str(val) for val in row) for row in values
+            )
             verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                 num_rows, cur_table, num_rows, rows_prompt
             )
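For a concrete sense of the replacement formatting, the tab-joined expression renders each row on its own line (values invented for illustration):

```python
values = [(1, "Alice", None), (2, "Bob", "NY")]
rows_prompt = "\n".join("\t".join(str(val) for val in row) for row in values)
print(rows_prompt)
# 1	Alice	None
# 2	Bob	NY
```

Unlike `nice_look_table`, this representation omits the header row, so `column_names` is still computed above but no longer appears in the prompt.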
@@ -406,7 +531,9 @@ def grpo_function(
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
                 {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
-            ]
+            ],
+            "question": question,
+            "evidence": external_knowledge,
         }
 
         ds.append(example)
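With this change each raw example carries the plain question and evidence alongside the chat messages, roughly (field values illustrative):

```python
example = {
    "messages": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "<schema + question prompt>"},
        {"role": "assistant", "content": "SELECT ...\t----- bird -----\tcalifornia_schools"},
    ],
    "question": "How many schools are there in Alameda County?",
    "evidence": "Alameda County refers to county = 'Alameda'",
}
```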
@@ -414,7 +541,9 @@ def grpo_function(
     dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
     dataset = Dataset.from_dict(dataset_dict)
 
-    def generate_r1_prompt(system_prompt, user_prompt, ground_truth):
+    def generate_r1_prompt(
+        system_prompt, user_prompt, ground_truth, question, evidence
+    ):
         r1_prefix = [
             {
                 "role": "system",
@@ -428,6 +557,8 @@ def grpo_function(
                 r1_prefix, tokenize=False, continue_final_message=True
             ),
             "answer": ground_truth,
+            "question": question,
+            "evidence": evidence,
         }
 
     # convert our dataset to the r1 prompt
@@ -436,6 +567,8 @@ def grpo_function(
             x["messages"][0]["content"],
             x["messages"][1]["content"],
             x["messages"][2]["content"],
+            x["question"],
+            x["evidence"],
         ),
         remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
     )
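Threading `question` and `evidence` through to the mapped dataset matters because `GRPOTrainer` forwards any extra dataset columns to each reward function as keyword arguments; that is what the `kwargs.get("question")` / `kwargs.get("evidence")` calls in the reward functions above rely on. Schematically:

```python
# What the trainer effectively does per batch (simplified sketch):
rewards = llm_as_a_judge_reward_func(
    completions=generated_texts,  # model outputs for the batch
    answer=batch["answer"],       # ground-truth SQL
    question=batch["question"],   # extra column, forwarded as a kwarg
    evidence=batch["evidence"],   # extra column, forwarded as a kwarg
)
```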