grpo_text2sql.py

import json
import logging
import os
import random
import re
import sqlite3
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import List

from datasets import Dataset
from func_timeout import func_timeout, FunctionTimedOut
from together import Together
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import get_peft_config, GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

TRAIN_JSON = "../../data/train/train.json"
DB_ROOT_PATH = "../../data/train/train_databases/"
LOG_REWARD_FILE_NAME = "text2sql_grpo_rewards5.log"
COMPLETION_SAMPLE_TXT_FILE_NAME = "completion_samples5.txt"


def load_json(path):
    with open(path, "r") as j:
        contents = json.loads(j.read())
    return contents


def execute_sql(predicted_sql, ground_truth_dbid):
    """
    Executes the predicted SQL and the gold SQL against the same SQLite database
    and returns 1 if their result sets match, 0 otherwise. `ground_truth_dbid`
    carries the gold SQL and the database id joined by the tab-delimited
    "----- bird -----" marker used when the dataset is built below.
    """
    ground_truth, db_name = ground_truth_dbid.split("\t----- bird -----\t")
    # print(f"\n==== execute_sql ====\n{predicted_sql=}\n{ground_truth=}")
    db_path = DB_ROOT_PATH + db_name + "/" + db_name + ".sqlite"
    # Connect to the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    try:
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        conn.close()
    res = 0
    if set(predicted_res) == set(ground_truth_res):
        res = 1
        print("execution result same")
    else:
        print("execution result different")
    return res


@dataclass
class ScriptArguments:
    tokenizer_name_or_path: str = None


########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)


########################
# Helper functions
########################
def log_reward(reason, completion, gt):
    os.makedirs("logs", exist_ok=True)
    log_file = os.path.join("logs", LOG_REWARD_FILE_NAME)
    with open(log_file, "a") as f:
        f.write("\n\n==============\n")
        f.write(f">>>{reason=}\n>>>{completion=}\n>>>{gt=}\n")


def extract_answer(text):
    """
    Extracts a final numeric answer from raw text (e.g. "#### 42" or
    "The final answer is 42"). Note: this helper is not used by the SQL
    reward functions below.
    """
    try:
        match = re.search(r"#### (\-?[\d\.,$]+)", text)
        if match:
            matched_string = match.group(1)
            # Remove any characters that would cause a ValueError,
            # such as dollar signs ($) and commas (,)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
        match = re.search(
            r"(?:The final answer is|The answer is):?\s*(\-?[\d\.,$]+)",
            text,
            re.IGNORECASE,
        )
        if match:
            matched_string = match.group(1)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
    except (ValueError, AttributeError):
        print(f"Error extracting answer from text: {text}")
    return None


def format_reward_func(completions, answer, **kwargs):
    """
    Format: <think>...</think><answer>...</answer>
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Expected answers
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join(
                    "completion_samples", COMPLETION_SAMPLE_TXT_FILE_NAME
                )
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            # Check if the format is correct
            regex = r"<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\s*<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # if the format is not correct, reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
                log_reward("format_reward 0", completion, gt)
            else:
                rewards.append(1.0)
                log_reward("format_reward 1", completion, gt)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"format_reward 0 - exception {e=}", completion, gt)
    return rewards
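
# A completion that earns a format reward of 1.0 looks like this (illustrative
# sketch, not taken from a real run):
#
#   <think> reason about the schema, the evidence and the question ... </think>
#   <answer> SELECT ... FROM ... </answer>
#
# Note that the <answer> block must be the last thing in the completion, since the
# regex above is anchored with "$".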


def execution_reward_func(completions, answer, **kwargs):
    """
    Evaluates completions based on SQL statement execution result
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # gt = extract_answer(gt)
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                log_reward("execution_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            reason = "execution result different"
            # execute the generated SQL and the ground truth SQL and compare the results
            try:
                res = func_timeout(
                    30.0,
                    execute_sql,
                    args=(predicted_sql, gt),
                )
            except KeyboardInterrupt:
                sys.exit(0)
            except FunctionTimedOut:
                print("FunctionTimedOut")
                reason = "execution timeout"
                res = 0
            except Exception as e:
                print("Exception", e)
                reason = f"execution exception {e}"
                res = 0
            if res == 1:
                # reason = "execution result same"
                rewards.append(1.0)
                log_reward("execution_reward 1", completion, gt)
            else:
                rewards.append(0.0)
                log_reward(
                    f"execution_reward 0 {reason=}, {predicted_sql=}",
                    completion,
                    gt,
                )
        except Exception as e:
            # If evaluation fails, reward is 0
            rewards.append(0.0)
            log_reward(f"execution_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_ngrams(tokens: List[str], n: int) -> set:
    """Generates a set of n-grams from a list of tokens."""
    # Ensure there are enough tokens to create at least one n-gram
    if len(tokens) < n:
        return set()
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def n_gram_jaccard_similarity(candidate_query: str, gold_query: str, n: int) -> float:
    """Calculates the n-gram Jaccard similarity for a single n."""
    # Tokenize the SQL queries. Using .lower() for case-insensitivity.
    candidate_tokens = candidate_query.lower().split()
    gold_tokens = gold_query.lower().split()
    # Get the n-grams for both sets of tokens.
    candidate_ngrams = get_ngrams(candidate_tokens, n)
    gold_ngrams = get_ngrams(gold_tokens, n)
    # Handle the edge case where one or both sets are empty.
    if not candidate_ngrams and not gold_ngrams:
        return 1.0
    if not candidate_ngrams or not gold_ngrams:
        return 0.0
    # Calculate Jaccard similarity.
    intersection = len(candidate_ngrams.intersection(gold_ngrams))
    union = len(candidate_ngrams.union(gold_ngrams))
    return intersection / union
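
# Quick illustration of the helper above (hedged example; the queries are made up
# and this is not executed during training): for
#   candidate = "SELECT name FROM users" and gold = "SELECT name FROM customers",
# the unigram sets share {"select", "name", "from"} out of 5 distinct unigrams in
# total, so n_gram_jaccard_similarity(candidate, gold, n=1) returns 3 / 5 = 0.6.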


def ensemble_n_gram_reward_func(completions, answer, **kwargs):
    """
    Calculates the averaged ensemble n-gram Jaccard similarity reward.
    This function computes the Jaccard similarity for n=1, 2, and 3
    and returns the average score for each sample.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    questions = kwargs.get("question")
    evidences = kwargs.get("evidence")
    for completion, gt, question, evidence in zip(
        completions, answer, questions, evidences
    ):
        try:
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                log_reward("n_gram_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            # Calculate Jaccard similarity for n=1, 2, and 3
            jaccard_1 = n_gram_jaccard_similarity(predicted_sql, gt, n=1)
            jaccard_2 = n_gram_jaccard_similarity(predicted_sql, gt, n=2)
            jaccard_3 = n_gram_jaccard_similarity(predicted_sql, gt, n=3)
            # Average the scores to get the final ensemble reward
            average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0
            print(f"{average_jaccard=}")
            rewards.append(average_jaccard)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"n_gram_reward 0 - exception {e=}", completion, gt)
    return rewards


def llm_as_a_judge_reward_func(completions, answer, **kwargs):
    """
    Use Llama 3.3 70B as a judge to evaluate the quality of the generated SQL statements by comparing them to the ground truth answers.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    client = Together()
    PROMPT_TEMPLATE = """
You are an experienced database expert. Your task is to evaluate a generated SQL query by comparing it
to the ground truth (gold) query and then assign a score between 0.0 and 2.0. A higher score indicates
the predicted query is more correct, while a score of 0.0 means it is completely incorrect.
Follow these evaluation rules strictly:
1. SELECT Clause:
   • Only select columns that are mentioned in the user's question.
   • Do not include unnecessary columns or values.
2. Aggregation (MAX/MIN):
   • Always perform JOINs before applying MAX() or MIN().
3. ORDER BY with Distinct Values:
   • Use a GROUP BY <column> before an ORDER BY <column> ASC|DESC to ensure distinct values.
4. Handling NULLs:
   • If a column may contain NULL values (indicated by "None" in value examples or explicitly
     mentioned), include a JOIN or a WHERE <column> IS NOT NULL clause.
5. FROM/JOIN Clauses:
   • Only include the tables essential for answering the question.
6. Strictly Follow Hints:
   • Adhere to all hints provided with the question.
7. Thorough Question Analysis:
   • Ensure all conditions and requirements mentioned in the question are addressed.
8. DISTINCT Keyword:
   • Use SELECT DISTINCT when the question requires unique values (e.g., IDs, URLs)
     or when column statistics (Value Statics) indicate its necessity.
9. Column Selection:
   • Carefully analyze column descriptions and hints to choose the correct column
     when similar columns exist across tables.
10. String Concatenation:
   • Do not use any string concatenation methods (e.g., || ' ' ||) in the SELECT clause.
11. JOIN Preference:
   • Prefer using INNER JOIN over nested SELECT statements.
12. Date Processing:
   • Use STRFTIME() for any date manipulations (e.g., STRFTIME('%Y', SOMETIME) to extract the year).
You are provided with the following inputs:
• Question: {QUESTION}
• Hint: {HINT}
• Gold Query: {GOLD_QUERY}
• Predicted Query: {PREDICTED_QUERY}
Based on the above, return a single numeric score between 0.0 and 2.0 that reflects how
correct the predicted query is compared to the gold query. Respond with only the score and
no additional explanation.
"""
    questions = kwargs.get("question")
    evidences = kwargs.get("evidence")
    for completion, gt, question, evidence in zip(
        completions, answer, questions, evidences
    ):
        try:
            match = re.search(r"<answer>(.*?)<\/answer>", completion)
            if match is None:
                rewards.append(0.0)
                log_reward(
                    "llm_as_a_judge_reward_func 0 - no answer tag found", completion, gt
                )
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            prompt = PROMPT_TEMPLATE.format(
                QUESTION=question,
                HINT=evidence,
                GOLD_QUERY=gt,
                PREDICTED_QUERY=predicted_sql,
            )
            response = client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            reward = float(response.choices[0].message.content)
            print(f"llm_as_a_judge_reward_func>>> {reward=}")
            rewards.append(reward)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
    return rewards


def get_checkpoint(training_args: GRPOConfig):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extracts the CREATE TABLE DDLs (and optionally example rows) from the SQLite
    database at db_path and formats them as a schema prompt.
    :param db_path: path to the .sqlite database file
    :param num_rows: if set, include this many example rows per table
    :return: the schema prompt string
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            # Format the rows as a simple table representation
            rows_prompt = "\n".join(
                "\t".join(str(val) for val in row) for row in values
            )
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
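
# For reference, the combined prompt built above has the following shape (the table
# names and text are illustrative placeholders, not taken from a real database):
#
#   -- DB Schema: CREATE TABLE <table1> (...)
#
#   CREATE TABLE <table2> (...)
#
#   -- External Knowledge: <evidence text>
#
#   -- Question: <natural language question>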


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")
    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    ds = []
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    with open(TRAIN_JSON, "r") as f:
        input_json = json.load(f)
    for i, item in tqdm(enumerate(input_json)):
        print(f"processing #{i+1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        SQL = item["SQL"]
        db_path = DB_ROOT_PATH + db_id + "/" + db_id + ".sqlite"
        prompt = generate_combined_prompts_one(
            db_path,
            question,
            knowledge=external_knowledge,
        )
        example = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
            ],
            "question": question,
            "evidence": external_knowledge,
        }
        ds.append(example)
    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
    dataset = Dataset.from_dict(dataset_dict)

    def generate_r1_prompt(
        system_prompt, user_prompt, ground_truth, question, evidence
    ):
        r1_prefix = [
            {
                "role": "system",
                "content": """You are great at reasoning and translating natural language questions to SQLite SQL queries. Given the DB Schema, External Knowledge, and Question, your task is to first generate step-by-step reasoning, then apply the reasoning to generate the SQLite select statement as the accurate translation of the Question. Enclose the step-by-step reasoning within the <think> </think> tags, and the final SQL statement within the <answer> </answer> tags, i.e. <think> reasoning steps </think> <answer> final SQL </answer>.""",
            },
            {"role": "user", "content": user_prompt},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                r1_prefix, tokenize=False, continue_final_message=True
            ),
            "answer": ground_truth,
            "question": question,
            "evidence": evidence,
        }
    # convert our dataset to the r1 prompt
    dataset = dataset.map(
        lambda x: generate_r1_prompt(
            x["messages"][0]["content"],
            x["messages"][1]["content"],
            x["messages"][2]["content"],
            x["question"],
            x["evidence"],
        ),
        remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
    )
    # split the dataset into train and test
    train_test_split = dataset.train_test_split(test_size=0.3)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]
    print("len(train_dataset)", len(train_dataset))
    print(train_dataset[0])
    print("len(eval_dataset)", len(eval_dataset))
    print(eval_dataset[0])
    ##########################
    # Instantiate GRPO trainer
    ##########################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[
            format_reward_func,
            execution_reward_func,
            ensemble_n_gram_reward_func,
            llm_as_a_judge_reward_func,
        ],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.tokenizer = tokenizer
    ###############
    # Training loop
    ###############
    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    # by default training_args.resume_from_checkpoint is None
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
    # Train the model
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs ***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    # Log and save metrics
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    logger.info("*** Training complete ***")
    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes
    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")
    # Save everything else on main process
    # if trainer.accelerator.is_main_process:
    #     trainer.create_model_card({"tags": ["rl", "grpo", "tutorial", "philschmid"]})
    # push to hub if needed
    # if training_args.push_to_hub is True:
    #     logger.info("Pushing to hub...")
    #     trainer.push_to_hub()
    logger.info("*** Training complete! ***")


def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()
    # Run the main training loop
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()
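
# Example launch (illustrative sketch only; the flag names come from TRL's
# ModelConfig / GRPOConfig and may differ across TRL versions, and the values
# below are placeholders):
#
#   python grpo_text2sql.py \
#       --model_name_or_path <base-model> \
#       --output_dir <output-dir> \
#       --num_train_epochs 1 \
#       --use_peft
#
# TrlParser.parse_args_and_config can also read the same fields from a YAML file
# passed via --config.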