|
@@ -274,10 +274,6 @@ def ensemble_n_gram_reward_func(completions, answer, **kwargs):
|
|
|
for completion, gt, question, evidence in zip(
|
|
for completion, gt, question, evidence in zip(
|
|
|
completions, answer, questions, evidences
|
|
completions, answer, questions, evidences
|
|
|
):
|
|
):
|
|
|
- # print(f">>>>>ensemble_n_gram_reward_func: {gt=}")
|
|
|
|
|
- # print(f">>>>>ensemble_n_gram_reward_func: {completion=}")
|
|
|
|
|
- # print(f">>>>>ensemble_n_gram_reward_func: {question=}")
|
|
|
|
|
- # print(f">>>>>ensemble_n_gram_reward_func: {evidence=}")
|
|
|
|
|
try:
|
|
try:
|
|
|
match = re.search(r"<answer>(.*?)<\/answer>", completion)
|
|
match = re.search(r"<answer>(.*?)<\/answer>", completion)
|
|
|
if match is None:
|
|
if match is None:
|
|
@@ -407,7 +403,9 @@ no additional explanation.
|
|
|
messages=[{"role": "user", "content": prompt}],
|
|
messages=[{"role": "user", "content": prompt}],
|
|
|
temperature=0,
|
|
temperature=0,
|
|
|
)
|
|
)
|
|
|
- rewards.append(float(response.choices[0].message.content))
|
|
|
|
|
|
|
+ reward = float(response.choices[0].message.content)
|
|
|
|
|
+ print(f"llm_as_a_judge_reward_func>>> {reward=}")
|
|
|
|
|
+ rewards.append(reward)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
rewards.append(0.0)
|
|
rewards.append(0.0)
|
|
|
log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
|
|
log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
|
|
@@ -593,6 +591,7 @@ def grpo_function(
|
|
|
format_reward_func,
|
|
format_reward_func,
|
|
|
execution_reward_func,
|
|
execution_reward_func,
|
|
|
ensemble_n_gram_reward_func,
|
|
ensemble_n_gram_reward_func,
|
|
|
|
|
+ llm_as_a_judge_reward_func,
|
|
|
],
|
|
],
|
|
|
args=training_args,
|
|
args=training_args,
|
|
|
train_dataset=train_dataset,
|
|
train_dataset=train_dataset,
|
|
@@ -607,7 +606,7 @@ def grpo_function(
|
|
|
###############
|
|
###############
|
|
|
# Check for last checkpoint
|
|
# Check for last checkpoint
|
|
|
last_checkpoint = get_checkpoint(training_args)
|
|
last_checkpoint = get_checkpoint(training_args)
|
|
|
- # JT: by default training_args.resume_from_checkpoint is None
|
|
|
|
|
|
|
+ # by default training_args.resume_from_checkpoint is None
|
|
|
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
|
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
|
|
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
|
|
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
|
|
|
|
|
|
|
@@ -652,10 +651,6 @@ def grpo_function(
|
|
|
def main():
|
|
def main():
|
|
|
parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
|
|
parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
|
|
|
model_args, script_args, training_args = parser.parse_args_and_config()
|
|
model_args, script_args, training_args = parser.parse_args_and_config()
|
|
|
- # print("model_args", model_args)
|
|
|
|
|
- # print("script_args", script_args)
|
|
|
|
|
- # print("training_args", training_args)
|
|
|
|
|
- # exit()
|
|
|
|
|
|
|
|
|
|
# Run the main training loop
|
|
# Run the main training loop
|
|
|
grpo_function(model_args, script_args, training_args)
|
|
grpo_function(model_args, script_args, training_args)
|
|
@@ -664,7 +659,12 @@ def main():
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
main()
|
|
main()
|
|
|
|
|
|
|
|
-# two ways to run this script:
|
|
|
|
|
-# with-proxy accelerate launch --num_processes 8 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
|
|
|
|
|
|
|
+# before running this script, make sure you set the following environment variable
|
|
|
|
|
+# so the reward of using LLM as a judge can be calculated:
|
|
|
|
|
+# export TOGETHER_API_KEY=<your together.ai api key>
|
|
|
|
|
|
|
|
-# with-proxy nohup accelerate launch --num_processes 4 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &
|
|
|
|
|
|
|
+# two ways to run this script, assuming you have 6 GPUs to use for the training
|
|
|
|
|
+
|
|
|
|
|
+# with-proxy accelerate launch --num_processes 6 --gpu_ids 2,3,4,5,6,7 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
|
|
|
|
|
+
|
|
|
|
|
+# with-proxy nohup accelerate launch --num_processes 6 --gpu_ids 2,3,4,5,6,7 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &
|