Selaa lähdekoodia

llm as a judge running now

Jeff Tang 3 viikkoa sitten
vanhempi
commit
7edf3d8df0

+ 2 - 2
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml

@@ -5,7 +5,7 @@ torch_dtype: bfloat16
 attn_implementation: flash_attention_2
 bf16: true
 tf32: true
-output_dir: runs/llama-3.2-3b-grpo-text2sql-3rewards-6gpu
+output_dir: runs/llama-3.2-3b-grpo-text2sql-4rewards-6gpu
 
 # Lora Arguments
 # No LoRA is used here
@@ -29,7 +29,7 @@ use_vllm: true
 
 # Reward function weights
 # Order: [format_reward_func, execution_reward_func, ensemble_n_gram_reward_func, llm_as_a_judge_reward_func]
-reward_weights: [1.0, 3.0, 1.0]
+reward_weights: [1.0, 3.0, 1.0, 1.0]
 # **Recommended Weight Strategy**
 # **Current Setting: `[1.0, 3.0, 1.0, 1.0]`**
 # *   **Format reward (1.0)**: Standard weight since format correctness is binary but essential

+ 13 - 13
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py

@@ -274,10 +274,6 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
     for completion, gt, question, evidence in zip(
         completions, answer, questions, evidences
     ):
-        # print(f">>>>>ensemble_n_gram_reward_func: {gt=}")
-        # print(f">>>>>ensemble_n_gram_reward_func: {completion=}")
-        # print(f">>>>>ensemble_n_gram_reward_func: {question=}")
-        # print(f">>>>>ensemble_n_gram_reward_func: {evidence=}")
         try:
             match = re.search(r"<answer>(.*?)<\/answer>", completion)
             if match is None:
@@ -407,7 +403,9 @@ no additional explanation.
                 messages=[{"role": "user", "content": prompt}],
                 temperature=0,
             )
-            rewards.append(float(response.choices[0].message.content))
+            reward = float(response.choices[0].message.content)
+            print(f"llm_as_a_judge_reward_func>>> {reward=}")
+            rewards.append(reward)
         except Exception as e:
             rewards.append(0.0)
             log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
@@ -593,6 +591,7 @@ def grpo_function(
             format_reward_func,
             execution_reward_func,
             ensemble_n_gram_reward_func,
+            llm_as_a_judge_reward_func,
         ],
         args=training_args,
         train_dataset=train_dataset,
@@ -607,7 +606,7 @@ def grpo_function(
     ###############
     # Check for last checkpoint
     last_checkpoint = get_checkpoint(training_args)
-    # JT: by default training_args.resume_from_checkpoint is None
+    # by default training_args.resume_from_checkpoint is None
     if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
         logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
 
@@ -652,10 +651,6 @@ def grpo_function(
 def main():
     parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
     model_args, script_args, training_args = parser.parse_args_and_config()
-    # print("model_args", model_args)
-    # print("script_args", script_args)
-    # print("training_args", training_args)
-    # exit()
 
     # Run the main training loop
     grpo_function(model_args, script_args, training_args)
@@ -664,7 +659,12 @@ def main():
 if __name__ == "__main__":
     main()
 
-# two ways to run this script:
-# with-proxy accelerate launch --num_processes 8 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
+# before running this script, make sure you set the following environment variable
+# so the reward of using LLM as a judge can be calculated:
+# export TOGETHER_API_KEY=<your together.ai api key>
 
-# with-proxy nohup accelerate launch --num_processes 4 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &
+# two ways to run this script, assuming you have 6 GPUs to use for the training
+
+# with-proxy accelerate launch --num_processes 6 --gpu_ids 2,3,4,5,6,7 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
+
+# with-proxy nohup accelerate launch --num_processes 6 --gpu_ids 2,3,4,5,6,7 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &