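"""GRPO training script for text-to-SQL on the BIRD train set.

Fine-tunes a causal LM with TRL's GRPOTrainer using three reward
functions: <think>/<answer> format compliance, SQL execution-result
accuracy against the gold query, and an ensemble n-gram Jaccard
similarity between the predicted and gold SQL.
"""
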
import json
import logging
import os
import random
import re
import sqlite3
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional

from datasets import Dataset
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import get_peft_config, GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
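
# Enable the hf_transfer download backend for faster Hugging Face Hub
# downloads (takes effect only if the hf_transfer package is installed).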
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

TRAIN_JSON = "../../data/train/train.json"
DB_ROOT_PATH = "../../data/train/train_databases/"
LOG_REWARD_FILE_NAME = "text2sql_grpo_rewards5.log"
COMPLETION_SAMPLE_TXT_FILE_NAME = "completion_samples5.txt"


def load_json(path):
    with open(path, "r") as j:
        contents = json.loads(j.read())
    return contents


def execute_sql(predicted_sql, ground_truth_dbid):
    # The ground truth packs the gold SQL and the db id into one string,
    # separated by a sentinel, e.g. (hypothetical values):
    #   "SELECT count(*) FROM frpm\t----- bird -----\tcalifornia_schools"
    ground_truth, db_name = ground_truth_dbid.split("\t----- bird -----\t")
    # print(f"\n==== execute_sql ====\n{predicted_sql=}\n{ground_truth=}")
    db_path = DB_ROOT_PATH + db_name + "/" + db_name + ".sqlite"
    # Connect to the database and execute both queries
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute(predicted_sql)
    predicted_res = cursor.fetchall()
    cursor.execute(ground_truth)
    ground_truth_res = cursor.fetchall()
    # Order-insensitive comparison of the two result sets
    res = 0
    if set(predicted_res) == set(ground_truth_res):
        res = 1
        print("execution result same")
    else:
        print("execution result different")
    conn.close()
    return res


@dataclass
class ScriptArguments:
    # Optional tokenizer override; falls back to model_name_or_path below
    tokenizer_name_or_path: Optional[str] = None


########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)


########################
# Helper functions
########################
def log_reward(reason, completion, gt):
    os.makedirs("logs", exist_ok=True)
    log_file = os.path.join("logs", LOG_REWARD_FILE_NAME)
    with open(log_file, "a") as f:
        f.write("\n\n==============\n")
        f.write(f">>>{reason=}\n>>>{completion=}\n>>>{gt=}\n")


def extract_answer(text):
    """
    Extracts a final numeric answer from the raw text.
    (Leftover from a numeric-answer reward; unused in the SQL reward path.)
    """
    try:
        match = re.search(r"#### (\-?[\d\.,$]+)", text)
        if match:
            matched_string = match.group(1)
            # Remove any characters that would cause a ValueError,
            # such as dollar signs ($) and commas (,)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
        match = re.search(
            r"(?:The final answer is|The answer is):?\s*(\-?[\d\.,$]+)",
            text,
            re.IGNORECASE,
        )
        if match:
            matched_string = match.group(1)
            cleaned_string = re.sub(r"[$,]", "", matched_string)
            return float(cleaned_string)
    except (ValueError, AttributeError):
        print(f"Error extracting answer from text: {text}")
    return None


def format_reward_func(completions, answer, **kwargs):
    """
    Format: <think>...</think><answer>...</answer>
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Expected answers
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join(
                    "completion_samples", COMPLETION_SAMPLE_TXT_FILE_NAME
                )
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)
            # Check if the format is correct
            regex = r"<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\s*<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            # if the format is not correct, reward is 0
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
                log_reward("format_reward 0", completion, gt)
            else:
                rewards.append(1.0)
                log_reward("format_reward 1", completion, gt)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"format_reward 0 - exception {e=}", completion, gt)
    return rewards
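

# A completion of this (hypothetical) shape earns format reward 1.0:
#   <think> we need a count, so join the tables and filter ... </think>
#   <answer> SELECT count(*) FROM t </answer>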


def execution_reward_func(completions, answer, **kwargs):
    """
    Evaluates completions based on SQL statement execution result
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (gold SQL + db_id, as
            packed into the "answer" field of the training dataset)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # gt = extract_answer(gt)
            # re.DOTALL so multi-line SQL inside the tags still matches
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("execution_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            reason = "execution result different"
            # execute the sql_generated and gt and compare the results,
            # with a 30s timeout per query pair
            try:
                res = func_timeout(
                    30.0,
                    execute_sql,
                    args=(predicted_sql, gt),
                )
            except KeyboardInterrupt:
                sys.exit(0)
            except FunctionTimedOut:
                print("FunctionTimedOut")
                reason = "execution timeout"
                res = 0
            except Exception as e:
                print("Exception", e)
                reason = f"execution exception {e}"
                res = 0
            if res == 1:
                rewards.append(1.0)
                log_reward("execution_reward 1", completion, gt)
            else:
                rewards.append(0.0)
                log_reward(
                    f"execution_reward 0 {reason=}, {predicted_sql=}",
                    completion,
                    gt,
                )
        except Exception as e:
            # If evaluation fails, reward is 0
            rewards.append(0.0)
            log_reward(f"execution_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_ngrams(tokens: List[str], n: int) -> set:
    """Generates a set of n-grams from a list of tokens."""
    # Ensure there are enough tokens to create at least one n-gram
    if len(tokens) < n:
        return set()
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}
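

# e.g. get_ngrams(["select", "*", "from"], 2)
#      -> {("select", "*"), ("*", "from")}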


def n_gram_jaccard_similarity(candidate_query: str, gold_query: str, n: int) -> float:
    """Calculates the n-gram Jaccard similarity for a single n."""
    # Tokenize the SQL queries. Using .lower() for case-insensitivity.
    candidate_tokens = candidate_query.lower().split()
    gold_tokens = gold_query.lower().split()
    # Get the n-grams for both sets of tokens.
    candidate_ngrams = get_ngrams(candidate_tokens, n)
    gold_ngrams = get_ngrams(gold_tokens, n)
    # Handle the edge case where one or both sets are empty.
    if not candidate_ngrams and not gold_ngrams:
        return 1.0
    if not candidate_ngrams or not gold_ngrams:
        return 0.0
    # Calculate Jaccard similarity: |intersection| / |union|.
    intersection = len(candidate_ngrams.intersection(gold_ngrams))
    union = len(candidate_ngrams.union(gold_ngrams))
    return intersection / union
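

# Worked example: for candidate "SELECT a FROM t" and gold "SELECT b FROM t",
# the unigram sets share {"select", "from", "t"} out of 5 distinct tokens,
# so the 1-gram Jaccard similarity is 3 / 5 = 0.6.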


def ensemble_n_gram_reward_func(completions, answer, **kwargs):
    """
    Calculates the averaged ensemble n-gram Jaccard similarity reward.
    This function computes the Jaccard similarity for n=1, 2, and 3
    and returns the average score for each sample.
    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (gold SQL + db_id, as
            packed into the "answer" field of the training dataset)
    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # re.DOTALL so multi-line SQL inside the tags still matches
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("n_gram_reward 0 - no answer tag found", completion, gt)
                continue
            # Extract the "answer" part from the completion
            predicted_sql = match.group(1).strip()
            # gt packs "SQL \t----- bird -----\t db_id"; compare against the
            # gold SQL only, so the sentinel tokens don't dilute the score
            gold_sql = gt.split("\t----- bird -----\t")[0]
            # Calculate Jaccard similarity for n=1, 2, and 3
            jaccard_1 = n_gram_jaccard_similarity(predicted_sql, gold_sql, n=1)
            jaccard_2 = n_gram_jaccard_similarity(predicted_sql, gold_sql, n=2)
            jaccard_3 = n_gram_jaccard_similarity(predicted_sql, gold_sql, n=3)
            # Average the scores to get the final ensemble reward
            average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0
            rewards.append(average_jaccard)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"n_gram_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_checkpoint(training_args: GRPOConfig):
    """Returns the most recent checkpoint under output_dir, or None."""
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint
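

# nice_look_table() is called by generate_schema_prompt() when num_rows is
# set, but its definition is missing from this file (the call path is dead
# here, since num_rows is always None). A minimal sketch, assuming a simple
# tab-separated rendering of the example rows:
def nice_look_table(column_names: List[str], values: List[tuple]) -> str:
    rows = ["\t".join(column_names)]
    for row in values:
        rows.append("\t".join(str(v) for v in row))
    return "\n".join(rows)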


def generate_schema_prompt(db_path, num_rows=None):
    """
    Extracts the CREATE TABLE DDL for every table in the SQLite DB at
    db_path; if num_rows is set, appends num_rows example rows per table.
    """
    full_schema_prompt_list = []
    conn = sqlite3.connect(db_path)
    # Create a cursor object
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    schemas = {}
    for table in tables:
        # fetchall() returns 1-tuples, so compare against table[0]
        if table[0] == "sqlite_sequence":
            continue
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
            (table[0],),
        )
        create_prompt = cursor.fetchone()[0]
        schemas[table[0]] = create_prompt
        if num_rows:
            cur_table = table[0]
            # Quote table names that collide with SQL keywords
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)
            cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    conn.close()
    for k, v in schemas.items():
        full_schema_prompt_list.append(v)
    schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list)
    return schema_prompt


def generate_comment_prompt(question, knowledge=None):
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    result_prompt = knowledge_prompt + "\n\n" + question_prompt
    return result_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    combined_prompts = schema_prompt + "\n\n" + comment_prompt
    return combined_prompts
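

# The combined prompt thus has the shape:
#   -- DB Schema: <CREATE TABLE ...>
#
#   -- External Knowledge: <evidence>
#
#   -- Question: <question>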


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")
    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    ds = []
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
    input_json = load_json(TRAIN_JSON)
    for i, item in tqdm(enumerate(input_json)):
        print(f"processing #{i+1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        SQL = item["SQL"]
        db_path = DB_ROOT_PATH + db_id + "/" + db_id + ".sqlite"
        prompt = generate_combined_prompts_one(
            db_path,
            question,
            knowledge=external_knowledge,
        )
        # The assistant content packs the gold SQL and db_id together so
        # the reward functions can recover both from the "answer" field.
        example = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id},
            ]
        }
        ds.append(example)
    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
    dataset = Dataset.from_dict(dataset_dict)

    def generate_r1_prompt(system_prompt, user_prompt, ground_truth):
        # Note: the SFT-style system_prompt argument is replaced here by an
        # R1-style reasoning system prompt.
        r1_prefix = [
            {
                "role": "system",
                "content": """You are great at reasoning and translating natural language question to SQLite SQL query. Given DB Schema, External Knowledge, and Question, your task is to first generate step-by-step reasoning, then apply the reasoning to generate the SQLite select statement as the accurate translation of the Question. Enclose the step-by-step reasoning within the <think> </think> tags, and the final SQL statement within the <answer> </answer> tags, i.e. <think> reasoning steps </think> <answer> final SQL </answer>.""",
            },
            {"role": "user", "content": user_prompt},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                r1_prefix, tokenize=False, continue_final_message=True
            ),
            "answer": ground_truth,
        }
    # convert our dataset to the r1 prompt
    dataset = dataset.map(
        lambda x: generate_r1_prompt(
            x["messages"][0]["content"],
            x["messages"][1]["content"],
            x["messages"][2]["content"],
        ),
        remove_columns=dataset.column_names,  # Remove original columns to avoid conflicts with apply_chat_template
    )
    # split the dataset into train and test
    train_test_split = dataset.train_test_split(test_size=0.3)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]
    print("len(train_dataset)", len(train_dataset))
    print(train_dataset[0])
    print("len(eval_dataset)", len(eval_dataset))
    print(eval_dataset[0])
    ##########################
    # Instantiate GRPO trainer
    ##########################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[
            format_reward_func,
            execution_reward_func,
            ensemble_n_gram_reward_func,
        ],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.tokenizer = tokenizer
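    # With multiple reward_funcs, TRL sums the per-function scores
    # (weighted via GRPOConfig.reward_weights, equal by default), so a
    # completion can earn up to 3.0 total reward here.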
    ###############
    # Training loop
    ###############
    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    # JT: by default training_args.resume_from_checkpoint is None
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
    # Train the model
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs ***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
    # Log and save metrics
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes
    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")
    # Save everything else on main process
    # if trainer.accelerator.is_main_process:
    #     trainer.create_model_card({"tags": ["rl", "grpo", "tutorial", "philschmid"]})
    # push to hub if needed
    # if training_args.push_to_hub is True:
    #     logger.info("Pushing to hub...")
    #     trainer.push_to_hub()
    logger.info("*** All done! ***")


def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()
    # print("model_args", model_args)
    # print("script_args", script_args)
    # print("training_args", training_args)
    # exit()
    # Run the main training loop
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()

# two ways to run this script:
# with-proxy accelerate launch --num_processes 8 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
# with-proxy nohup accelerate launch --num_processes 4 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &