# grpo_text2sql.py

import json
import logging
import os
import random
import re
import sqlite3
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional

from datasets import Dataset
from func_timeout import func_timeout, FunctionTimedOut
from together import Together
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import get_peft_config, GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

TRAIN_JSON = "../../data/train/train.json"
DB_ROOT_PATH = "../../data/train/train_databases/"
LOG_REWARD_FILE_NAME = "text2sql_grpo_rewards5.log"
COMPLETION_SAMPLE_TXT_FILE_NAME = "completion_samples5.txt"


def load_json(dir):
    with open(dir, "r") as j:
        contents = json.loads(j.read())
    return contents


def execute_sql(predicted_sql, ground_truth_dbid):
    ground_truth, db_name = ground_truth_dbid.split("\t----- bird -----\t")
    # print(f"\n==== execute_sql ====\n{predicted_sql=}\n{ground_truth=}")
    db_path = DB_ROOT_PATH + db_name + "/" + db_name + ".sqlite"
    # Connect to the database
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(predicted_sql)
    predicted_res = cursor.fetchall()
    cursor.execute(ground_truth)
    ground_truth_res = cursor.fetchall()
    res = 0
    if set(predicted_res) == set(ground_truth_res):
        res = 1
        print("execution result same")
    else:
        print("execution result different")
    conn.close()
    return res
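
# Illustrative usage (assumes a BIRD-style database directory, here hypothetically
# named "california_schools", exists under DB_ROOT_PATH): the ground truth bundles
# the gold SQL and the db_id with the "\t----- bird -----\t" separator, matching
# how the dataset is built in grpo_function() below.
#
#   gt = "SELECT COUNT(*) FROM schools" + "\t----- bird -----\t" + "california_schools"
#   execute_sql("SELECT COUNT(*) FROM schools", gt)  # -> 1 if the result sets match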


@dataclass
class ScriptArguments:
    tokenizer_name_or_path: Optional[str] = None


########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)


########################
# Helper functions
########################
def log_reward(reason, completion, gt):
    os.makedirs("logs", exist_ok=True)
    log_file = os.path.join("logs", LOG_REWARD_FILE_NAME)
    with open(log_file, "a") as f:
        f.write("\n\n==============\n")
        f.write(f">>>{reason=}\n>>>{completion=}\n>>>{gt=}\n")


def extract_answer(text):
    """
    Extracts a final numeric answer (e.g. "#### 42" or "The final answer is 42")
    from the raw text. Not used by the SQL reward functions below.
    """
    try:
        match = re.search(r"#### (\-?[\d\.,$]+)", text)
        if match:
            matched_string = match.group(1)
            # Remove any characters that would cause a ValueError,
            # such as dollar signs ($) and commas (,)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
        match = re.search(
            r"(?:The final answer is|The answer is):?\s*(\-?[\d\.,$]+)",
            text,
            re.IGNORECASE,
        )
        if match:
            matched_string = match.group(1)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
    except (ValueError, AttributeError):
        print(f"Error extracting answer from text: {text}")
    return None
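
# Illustrative: extract_answer("The final answer is $1,234.5") -> 1234.5,
# and extract_answer("#### 42") -> 42.0; text matching neither pattern returns None.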


def format_reward_func(completions, answer, **kwargs):
    """
    Checks the format: <think>...</think><answer>...</answer>
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Expected answers
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join(
                    "completion_samples", COMPLETION_SAMPLE_TXT_FILE_NAME
                )
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            # Check if the format is correct
            regex = r"<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\s*<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # If the format is not correct, the reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
                log_reward("format_reward 0", completion, gt)
            else:
                rewards.append(1.0)
                log_reward("format_reward 1", completion, gt)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"format_reward 0 - exception {e=}", completion, gt)
    return rewards


def execution_reward_func(completions, answer, **kwargs):
    """
    Evaluates completions based on the SQL statement execution result.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # gt = extract_answer(gt)
            # re.DOTALL so multi-line SQL inside the answer tags still matches
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("execution_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            reason = "execution result different"
            # Execute the generated SQL and the ground truth, then compare the results
            try:
                res = func_timeout(
                    30.0,
                    execute_sql,
                    args=(predicted_sql, gt),
                )
            except KeyboardInterrupt:
                sys.exit(0)
            except FunctionTimedOut:
                print("FunctionTimedOut")
                reason = "execution timeout"
                res = 0
            except Exception as e:
                print("Exception", e)
                reason = f"execution exception {e}"
                res = 0
            if res == 1:
                rewards.append(1.0)
                log_reward("execution_reward 1", completion, gt)
            else:
                rewards.append(0.0)
                log_reward(
                    f"execution_reward 0 {reason=}, {predicted_sql=}",
                    completion,
                    gt,
                )
        except Exception as e:
            # If evaluation fails, the reward is 0
            rewards.append(0.0)
            log_reward(f"execution_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_ngrams(tokens: List[str], n: int) -> set:
    """Generates a set of n-grams from a list of tokens."""
    # Ensure there are enough tokens to create at least one n-gram
    if len(tokens) < n:
        return set()
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def n_gram_jaccard_similarity(candidate_query: str, gold_query: str, n: int) -> float:
    """Calculates the n-gram Jaccard similarity for a single n."""
    # Tokenize the SQL queries. Using .lower() for case-insensitivity.
    candidate_tokens = candidate_query.lower().split()
    gold_tokens = gold_query.lower().split()
    # Get the n-grams for both sets of tokens.
    candidate_ngrams = get_ngrams(candidate_tokens, n)
    gold_ngrams = get_ngrams(gold_tokens, n)
    # Handle the edge case where one or both sets are empty.
    if not candidate_ngrams and not gold_ngrams:
        return 1.0
    if not candidate_ngrams or not gold_ngrams:
        return 0.0
    # Jaccard similarity: |intersection| / |union|
    intersection = len(candidate_ngrams.intersection(gold_ngrams))
    union = len(candidate_ngrams.union(gold_ngrams))
    return intersection / union
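
# Worked example (illustrative): for
#   candidate = "SELECT name FROM schools"
#   gold      = "SELECT name FROM students"
# the unigram sets share {"select", "name", "from"} out of 5 distinct tokens,
# so n_gram_jaccard_similarity(candidate, gold, n=1) == 3 / 5 == 0.6.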


def ensemble_n_gram_reward_func(completions, answer, **kwargs):
    """
    Calculates the averaged ensemble n-gram Jaccard similarity reward.
    This function computes the Jaccard similarity for n=1, 2, and 3
    and returns the average score for each sample.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    questions = kwargs.get("question")
    evidences = kwargs.get("evidence")
    for completion, gt, question, evidence in zip(
        completions, answer, questions, evidences
    ):
        # print(f">>>>>ensemble_n_gram_reward_func: {gt=}")
        # print(f">>>>>ensemble_n_gram_reward_func: {completion=}")
        # print(f">>>>>ensemble_n_gram_reward_func: {question=}")
        # print(f">>>>>ensemble_n_gram_reward_func: {evidence=}")
        try:
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("n_gram_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            # The ground truth bundles "<SQL>\t----- bird -----\t<db_id>";
            # compare against the SQL part only
            gt_sql = gt.split("\t----- bird -----\t")[0]
            # Calculate Jaccard similarity for n=1, 2, and 3
            jaccard_1 = n_gram_jaccard_similarity(predicted_sql, gt_sql, n=1)
            jaccard_2 = n_gram_jaccard_similarity(predicted_sql, gt_sql, n=2)
            jaccard_3 = n_gram_jaccard_similarity(predicted_sql, gt_sql, n=3)
            # Average the scores to get the final ensemble reward
            average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0
            print(f"{average_jaccard=}")
            rewards.append(average_jaccard)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"n_gram_reward 0 - exception {e=}", completion, gt)
    return rewards
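
# Illustrative: if the 1-, 2-, and 3-gram Jaccard scores for a prediction were
# 0.6, 0.25, and 0.0, the ensemble reward would be (0.6 + 0.25 + 0.0) / 3 ≈ 0.283.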


def llm_as_a_judge_reward_func(completions, answer, **kwargs):
    """
    Uses Llama 3.3 70B as a judge to evaluate the quality of the generated SQL statements by comparing them to the ground truth answers.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    client = Together()
    PROMPT_TEMPLATE = """
You are an experienced database expert. Your task is to evaluate a generated SQL query by comparing it
to the ground truth (gold) query and then assign a score between 0.0 and 2.0. A higher score indicates
the predicted query is more correct, while a score of 0.0 means it is completely incorrect.
Follow these evaluation rules strictly:
1. SELECT Clause:
   • Only select columns that are mentioned in the user's question.
   • Do not include unnecessary columns or values.
2. Aggregation (MAX/MIN):
   • Always perform JOINs before applying MAX() or MIN().
3. ORDER BY with Distinct Values:
   • Use a GROUP BY <column> before an ORDER BY <column> ASC|DESC to ensure distinct values.
4. Handling NULLs:
   • If a column may contain NULL values (indicated by "None" in value examples or explicitly mentioned), include a JOIN or a WHERE <column> IS NOT NULL clause.
5. FROM/JOIN Clauses:
   • Only include the tables essential for answering the question.
6. Strictly Follow Hints:
   • Adhere to all hints provided with the question.
7. Thorough Question Analysis:
   • Ensure all conditions and requirements mentioned in the question are addressed.
8. DISTINCT Keyword:
   • Use SELECT DISTINCT when the question requires unique values (e.g., IDs, URLs) or when column statistics (Value Statics) indicate its necessity.
9. Column Selection:
   • Carefully analyze column descriptions and hints to choose the correct column when similar columns exist across tables.
10. String Concatenation:
   • Do not use any string concatenation methods (e.g., || ' ' ||) in the SELECT clause.
11. JOIN Preference:
   • Prefer using INNER JOIN over nested SELECT statements.
12. Date Processing:
   • Use STRFTIME() for any date manipulations (e.g., STRFTIME('%Y', SOMETIME) to extract the year).
You are provided with the following inputs:
• Question: {QUESTION}
• Hint: {HINT}
• Gold Query: {GOLD_QUERY}
• Predicted Query: {PREDICTED_QUERY}
Based on the above, return a single numeric score between 0.0 and 2.0 that reflects how
correct the predicted query is compared to the gold query. Respond with only the score and
no additional explanation.
"""
    questions = kwargs.get("question")
    evidences = kwargs.get("evidence")
    for completion, gt, question, evidence in zip(
        completions, answer, questions, evidences
    ):
        try:
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward(
                    "llm_as_a_judge_reward_func 0 - no answer tag found", completion, gt
                )
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            # The ground truth bundles "<SQL>\t----- bird -----\t<db_id>";
            # pass only the SQL part to the judge
            gt_sql = gt.split("\t----- bird -----\t")[0]
            prompt = PROMPT_TEMPLATE.format(
                QUESTION=question,
                HINT=evidence,
                GOLD_QUERY=gt_sql,
                PREDICTED_QUERY=predicted_sql,
            )
            response = client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
            )
            rewards.append(float(response.choices[0].message.content))
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt)
    return rewards
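
# The judge is prompted to reply with a bare number such as "1.5"; if it replies
# with anything float() cannot parse, the except clause maps that sample to 0.0.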


def get_checkpoint(training_args: GRPOConfig):
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extracts the CREATE TABLE DDLs (optionally with example rows) from the
    SQLite database at db_path and formats them as a schema prompt.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(
                table[0]
            )
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            # Format the rows as a simple table representation
            rows_prompt = "\n".join(
                "\t".join(str(val) for val in row) for row in values
            )
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt
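
# Illustrative output shape (the actual DDL depends on the database):
#   -- DB Schema: CREATE TABLE frpm (CDSCode TEXT, ...)
#
#   CREATE TABLE schools (CDSCode TEXT PRIMARY KEY, ...)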


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")
    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    ds = []
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    input_json = json.load(open(TRAIN_JSON, "r"))
    for i, item in tqdm(enumerate(input_json)):
        print(f"processing #{i + 1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        SQL = item["SQL"]
        db_path = DB_ROOT_PATH + db_id + "/" + db_id + ".sqlite"
        prompt = generate_combined_prompts_one(
            db_path,
            question,
            knowledge=external_knowledge,
        )
        example = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
            ],
            "question": question,
            "evidence": external_knowledge,
        }
        ds.append(example)
    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
    dataset = Dataset.from_dict(dataset_dict)

    def generate_r1_prompt(
        system_prompt, user_prompt, ground_truth, question, evidence
    ):
        r1_prefix = [
            {
                "role": "system",
                "content": """You are great at reasoning and translating natural language questions to SQLite SQL queries. Given the DB Schema, External Knowledge, and Question, your task is to first generate step-by-step reasoning, then apply the reasoning to generate the SQLite select statement as the accurate translation of the Question. Enclose the step-by-step reasoning within the <think> </think> tags, and the final SQL statement within the <answer> </answer> tags, i.e. <think> reasoning steps </think> <answer> final SQL </answer>.""",
            },
            {"role": "user", "content": user_prompt},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                r1_prefix, tokenize=False, continue_final_message=True
            ),
            "answer": ground_truth,
            "question": question,
            "evidence": evidence,
        }
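
    # Illustrative: each mapped example becomes a dict of the form
    #   {"prompt": "<chat-templated system + user text>",
    #    "answer": "<SQL>\t----- bird -----\t<db_id>",
    #    "question": "...", "evidence": "..."}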

    # Convert our dataset to the r1 prompt
    dataset = dataset.map(
        lambda x: generate_r1_prompt(
            x["messages"][0]["content"],
            x["messages"][1]["content"],
            x["messages"][2]["content"],
            x["question"],
            x["evidence"],
        ),
        remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
    )
    # Split the dataset into train and test
    train_test_split = dataset.train_test_split(test_size=0.3)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]
    print("len(train_dataset)", len(train_dataset))
    print(train_dataset[0])
    print("len(eval_dataset)", len(eval_dataset))
    print(eval_dataset[0])

    ##########################
    # Instantiate GRPO trainer
    ##########################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[
            format_reward_func,
            execution_reward_func,
            ensemble_n_gram_reward_func,
        ],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.tokenizer = tokenizer

    ###############
    # Training loop
    ###############
    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    # JT: by default training_args.resume_from_checkpoint is None
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
    # Train the model
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs ***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    # Log and save metrics
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to load
    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")
    # Save everything else on the main process
    # if trainer.accelerator.is_main_process:
    #     trainer.create_model_card({"tags": ["rl", "grpo", "tutorial", "philschmid"]})
    # Push to hub if needed
    # if training_args.push_to_hub is True:
    #     logger.info("Pushing to hub...")
    #     trainer.push_to_hub()
    logger.info("*** Training complete! ***")


def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()
    # print("model_args", model_args)
    # print("script_args", script_args)
    # print("training_args", training_args)
    # exit()
    # Run the main training loop
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()

# Two ways to run this script:
# with-proxy accelerate launch --num_processes 8 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
# with-proxy nohup accelerate launch --num_processes 4 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &