Browse files

added llm as a judge reward func

Jeff Tang, 4 weeks ago
parent
commit
57c05170eb

+ 3 - 3
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml

@@ -5,13 +5,13 @@ torch_dtype: bfloat16
 attn_implementation: flash_attention_2
 bf16: true
 tf32: true
-output_dir: runs/llama-3.2-3b-grpo-text2sql-alltrain-lr5-ng8
+output_dir: runs/llama-3.2-3b-grpo-text2sql-3rewards-6gpu
 
 # Lora Arguments
 # No LoRA is used here
 
 # Training arguments
-max_steps: 500 # 1000 #500
+max_steps: 750 # 1000 #500
 per_device_train_batch_size: 1
 gradient_accumulation_steps: 8
 gradient_checkpointing: true
@@ -46,7 +46,7 @@ reward_weights: [1.0, 3.0, 1.0]
 # **Similarity-focused: `[1.0, 2.0, 2.0]`**
 # *   Higher weight on N-gram similarity
 # *   Useful if execution often fails initially
-
+# final_reward = format_reward*1.0 + execution_reward*3.0 + ngram_reward*1.0
 
 vllm_device: "cuda:0" # use vLLM for generation and DeepSpeed for distributed training.
 # Set the num_processes to the number of GPUs you have -
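
As a rough worked example of the weighting above (illustrative numbers only, not values from a real run): with reward_weights [1.0, 3.0, 1.0] in the order [format, execution, n-gram], a completion that is well-formatted (1.0), executes correctly (1.0), and has moderate n-gram overlap (0.4) combines to 1.0*1.0 + 1.0*3.0 + 0.4*1.0 = 4.4, so execution correctness dominates the signal:

    # Illustrative only -- assumed per-sample scores, not outputs of the reward functions.
    weights = [1.0, 3.0, 1.0]   # [format, execution, n-gram]
    rewards = [1.0, 1.0, 0.4]
    final_reward = sum(w * r for w, r in zip(weights, rewards))  # 4.4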

+ 138 - 5
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py

@@ -12,6 +12,7 @@ from typing import List
 
 from datasets import Dataset
 from func_timeout import func_timeout, FunctionTimedOut
+from together import Together
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.trainer_utils import get_last_checkpoint
@@ -267,8 +268,16 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
     """
     """
 
 
     rewards = []
     rewards = []
-
-    for completion, gt in zip(completions, answer):
+    questions = kwargs.get("question")
+    evidences = kwargs.get("evidence")
+
+    for completion, gt, question, evidence in zip(
+        completions, answer, questions, evidences
+    ):
+        # print(f">>>>>ensemble_n_gram_reward_func: {gt=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {completion=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {question=}")
+        # print(f">>>>>ensemble_n_gram_reward_func: {evidence=}")
         try:
         try:
             match = re.search(r"<answer>(.*?)<\/answer>", completion)
             match = re.search(r"<answer>(.*?)<\/answer>", completion)
             if match is None:
             if match is None:
@@ -285,6 +294,7 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
 
             # Average the scores to get the final ensemble reward
             average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0
+            print(f"{average_jaccard=}")
             rewards.append(average_jaccard)
         except Exception as e:
             rewards.append(0.0)
@@ -293,6 +303,118 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
     return rewards
 
 
+def llm_as_a_judge_reward_func(completions, answer, **kwargs):
+    """
+    Use Llama 3.3 70b as a judge to evaluate the quality of the generated SQL statements by comparing them to the ground truth answers.
+
+    Args:
+        completions (list[str]): Generated outputs
+        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
+
+    Returns:
+        list[float]: Reward scores
+    """
+
+    rewards = []
+
+    client = Together()
+    PROMPT_TEMPLATE = """
+You are an experienced database expert. Your task is to evaluate a generated SQL query by comparing it
+to the ground truth (gold) query and then assign a score between 0.0 and 2.0. A higher score indicates
+the predicted query is more correct, while a score of 0.0 means it is completely incorrect.
+
+Follow these evaluation rules strictly:
+
+1. SELECT Clause:
+• Only select columns that are mentioned in the user’s question.
+• Do not include unnecessary columns or values.
+
+2. Aggregation (MAX/MIN):
+• Always perform JOINs before applying MAX() or MIN().
+
+3. ORDER BY with Distinct Values:
+• Use a GROUP BY <column> before an ORDER BY <column> ASC|DESC to ensure
+distinct values.
+
+4. Handling NULLs:
+• If a column may contain NULL values (indicated by "None" in value examples
+or explicitly mentioned), include a JOIN or a WHERE <column> IS NOT NULL
+clause.
+
+5. FROM/JOIN Clauses:
+• Only include the tables essential for answering the question.
+
+6. Strictly Follow Hints:
+• Adhere to all hints provided with the question.
+
+7. Thorough Question Analysis:
+• Ensure all conditions and requirements mentioned in the question are
+addressed.
+
+8. DISTINCT Keyword:
+• Use SELECT DISTINCT when the question requires unique values (e.g., IDs, URLs)
+or when column statistics (Value Statics) indicate its necessity.
+
+9. Column Selection:
+• Carefully analyze column descriptions and hints to choose the correct column
+when similar columns exist across tables.
+
+10. String Concatenation:
+• Do not use any string concatenation methods (e.g., || ' ' ||) in the SELECT
+clause.
+
+11. JOIN Preference:
+• Prefer using INNER JOIN over nested SELECT statements.
+
+12. Date Processing:
+• Use STRFTIME() for any date manipulations (e.g., STRFTIME('%Y', SOMETIME) to
+extract the year).
+
+You are provided with the following inputs:
+• Question: {QUESTION}
+• Hint: {HINT}
+• Gold Query: {GOLD_QUERY}
+• Predicted Query: {PREDICTED_QUERY}
+
+Based on the above, return a single numeric score between 0.0 and 2.0 that reflects how
+correct the predicted query is compared to the gold query. Respond with only the score and
+no additional explanation.
+"""
+
+    questions = kwargs.get("question")
+    evidences = kwargs.get("evidence")
+    for completion, gt, question, evidence in zip(
+        completions, answer, questions, evidences
+    ):
+        try:
+            match = re.search(r"<answer>(.*?)<\/answer>", completion)
+            if match is None:
+                rewards.append(0.0)
+                log_reward(
+                    "llm_as_a_judge_reward_func 0 - no answer tag found", completion, gt
+                )
+                continue
+            # Extract the "answer" part from the completion
+            predicted_sql = match.group(1).strip()
+            prompt = PROMPT_TEMPLATE.format(
+                QUESTION=question,
+                HINT=evidence,
+                GOLD_QUERY=gt,
+                PREDICTED_QUERY=predicted_sql,
+            )
+            response = client.chat.completions.create(
+                model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0,
+            )
+            rewards.append(float(response.choices[0].message.content))
+        except Exception as e:
+            rewards.append(0.0)
+            log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
+
+    return rewards
+
+
 def get_checkpoint(training_args: GRPOConfig):
     last_checkpoint = None
     if os.path.isdir(training_args.output_dir):
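
A note on the new llm_as_a_judge_reward_func: float(response.choices[0].message.content) only succeeds if the judge replies with a bare number, so any extra wording trips the except branch and the sample silently gets a reward of 0.0. A minimal, more forgiving parser could look like the sketch below (parse_judge_score is a hypothetical helper, not part of this commit):

    import re

    def parse_judge_score(text: str) -> float:
        # Hypothetical helper: extract the first number from the judge's reply
        # and clamp it to the 0.0-2.0 range the prompt asks for.
        match = re.search(r"\d+(?:\.\d+)?", text)
        if match is None:
            return 0.0
        return max(0.0, min(2.0, float(match.group(0))))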
@@ -332,7 +454,10 @@ def generate_schema_prompt(db_path, num_rows=None):
             cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
             column_names = [description[0] for description in cursor.description]
             values = cursor.fetchall()
-            rows_prompt = nice_look_table(column_names=column_names, values=values)
+            # Format the rows as a simple table representation
+            rows_prompt = "\n".join(
+                "\t".join(str(val) for val in row) for row in values
+            )
             verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                 num_rows, cur_table, num_rows, rows_prompt
             )
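
For illustration, the new inline formatting turns each fetched row into a tab-separated line (example values assumed here, not taken from the BIRD databases):

    values = [(1, "Alice"), (2, "Bob")]
    rows_prompt = "\n".join("\t".join(str(val) for val in row) for row in values)
    # rows_prompt == "1\tAlice\n2\tBob"

Note that column_names is no longer used in the prompt, so the example rows appear without a header line.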
@@ -406,7 +531,9 @@ def grpo_function(
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": prompt},
                 {"role": "user", "content": prompt},
                 {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
                 {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
-            ]
+            ],
+            "question": question,
+            "evidence": external_knowledge,
         }
         }
 
 
         ds.append(example)
         ds.append(example)
@@ -414,7 +541,9 @@ def grpo_function(
     dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
     dataset = Dataset.from_dict(dataset_dict)
 
-    def generate_r1_prompt(system_prompt, user_prompt, ground_truth):
+    def generate_r1_prompt(
+        system_prompt, user_prompt, ground_truth, question, evidence
+    ):
         r1_prefix = [
             {
                 "role": "system",
@@ -428,6 +557,8 @@ def grpo_function(
                 r1_prefix, tokenize=False, continue_final_message=True
             ),
             "answer": ground_truth,
+            "question": question,
+            "evidence": evidence,
         }
 
     # convert our dataset to the r1 prompt
@@ -436,6 +567,8 @@ def grpo_function(
             x["messages"][0]["content"],
             x["messages"][0]["content"],
             x["messages"][1]["content"],
             x["messages"][1]["content"],
             x["messages"][2]["content"],
             x["messages"][2]["content"],
+            x["question"],
+            x["evidence"],
         ),
         ),
         remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
         remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
     )
     )
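
Why the extra "question" and "evidence" keys matter: as far as I can tell from TRL's GRPOTrainer (trl 0.14), any dataset columns beyond the prompt are forwarded to each reward function as keyword arguments, which is exactly what kwargs.get("question") and kwargs.get("evidence") in the reward functions rely on. A minimal sketch of that contract (names here are illustrative, not part of the commit):

    def example_reward_func(completions, answer, **kwargs):
        # "question" and "evidence" arrive here because they are columns of the mapped dataset
        questions = kwargs.get("question")
        evidences = kwargs.get("evidence")
        return [0.0 for _ in completions]  # placeholder scores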

+ 1 - 0
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/requirements.txt

@@ -11,3 +11,4 @@ trl==0.14.0
 peft==0.15.2
 vllm==0.7.0
 func_timeout==4.3.5
+together==1.5.26
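
One setup detail implied by the new together dependency: the Together() client constructed inside llm_as_a_judge_reward_func reads its credentials from the TOGETHER_API_KEY environment variable (or an explicit api_key argument), so the variable has to be set in the training environment before the judge reward can return non-zero scores. A minimal sketch, assuming the key is exported up front:

    import os
    from together import Together

    # Assumes TOGETHER_API_KEY was exported before launching training;
    # otherwise pass api_key=... explicitly.
    assert os.environ.get("TOGETHER_API_KEY"), "set TOGETHER_API_KEY before training"
    client = Together()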