import json
import logging
import os
import random
import re
import sqlite3
import sys
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional

from datasets import Dataset
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import get_peft_config, GRPOConfig, GRPOTrainer, ModelConfig, TrlParser

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Paths to the BIRD training questions and the per-question SQLite databases.
TRAIN_JSON = "../../data/train/train.json"
DB_ROOT_PATH = "../../data/train/train_databases/"
# Output file names for reward logging (under logs/) and sampled completions
# (under completion_samples/).
LOG_REWARD_FILE_NAME = "text2sql_grpo_rewards5.log"
COMPLETION_SAMPLE_TXT_FILE_NAME = "completion_samples5.txt"

# Marker used to pack "<gold SQL><marker><db_id>" into a single answer string
# so it survives the dataset pipeline as one field; execute_sql() splits it.
BIRD_SEPARATOR = "\t----- bird -----\t"


def load_json(path):
    """Load and return the parsed JSON content of *path*."""
    with open(path, "r") as j:
        return json.load(j)


def execute_sql(predicted_sql, ground_truth_dbid):
    """Run the predicted and gold SQL against the question's database.

    Args:
        predicted_sql: SQL text extracted from a model completion.
        ground_truth_dbid: packed string "<gold SQL>\\t----- bird -----\\t<db_id>".

    Returns:
        1 when both queries produce the same set of rows (order-insensitive),
        0 otherwise. Exceptions from sqlite3 propagate to the caller.
    """
    ground_truth, db_name = ground_truth_dbid.split(BIRD_SEPARATOR)
    db_path = DB_ROOT_PATH + db_name + "/" + db_name + ".sqlite"
    conn = sqlite3.connect(db_path)  # Connect to the database
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
        # Compare as sets: row order (and duplicate rows) are ignored.
        if set(predicted_res) == set(ground_truth_res):
            print("execution result same")
            return 1
        print("execution result different")
        return 0
    finally:
        # fix: the original leaked the connection whenever a query raised.
        conn.close()


@dataclass
class ScriptArguments:
    # Optional tokenizer override; falls back to model_name_or_path when None.
    tokenizer_name_or_path: Optional[str] = None


########################
# Setup logging
########################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(handler)
# fix: without this, every record also propagated to the root handler that
# basicConfig installed, so each message was printed twice.
logger.propagate = False


########################
# Helper functions
########################
def log_reward(reason, completion, gt):
    """Append one reward decision (reason, completion, ground truth) to the reward log."""
    os.makedirs("logs", exist_ok=True)
    log_file = os.path.join("logs", LOG_REWARD_FILE_NAME)
    with open(log_file, "a") as f:
        f.write("\n\n==============\n")
        f.write(f">>>{reason=}\n>>>{completion=}\n>>>{gt=}\n")


def extract_answer(text):
    """
    Extract a numeric final answer from raw text.

    NOTE(review): this is a GSM8K-style numeric extractor ("#### 42" /
    "The final answer is 42") and is not used by the SQL reward path below;
    kept for compatibility. Returns a float, or None when nothing matches.
    """
    match = None
    try:
        match = re.search(r"#### (\-?[\d\.,$]+)", text)
        if match:
            # Strip characters that would break float(), such as $ and commas.
            return float(re.sub(r"[$,]", "", match.group(1)))
        match = re.search(
            r"(?:The final answer is|The answer is):?\s*(\-?[\d\.,$]+)",
            text,
            re.IGNORECASE,
        )
        if match:
            return float(re.sub(r"[$,]", "", match.group(1)))
    except (ValueError, AttributeError):
        # fix: the original re-read match.group(1) here, which could itself
        # raise when match was None.
        matched = match.group(1) if match else text
        print(f"Error extracting answer from text: {matched}")
    return None


def format_reward_func(completions, answer, **kwargs):
    """
    Reward 1.0 when a completion matches the expected
    <think> ... </think> <answer> ... </answer> layout, else 0.0.

    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Expected answers

    Returns:
        list[float]: Reward scores
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            if random.random() < 0.1:  # 10% chance to write samples into a file
                os.makedirs("completion_samples", exist_ok=True)
                log_file = os.path.join(
                    "completion_samples", COMPLETION_SAMPLE_TXT_FILE_NAME
                )
                with open(log_file, "a") as f:
                    f.write("\n\n==============\n")
                    f.write(completion)

            # Check if the format is correct.
            # NOTE(review): the literal <think>/<answer> opening tags were lost
            # when this file was extracted; they are reconstructed here —
            # confirm against the original script.
            regex = r"<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\s*<answer>([\s\S]*?)<\/answer>$"
            match = re.search(regex, completion, re.DOTALL)
            if match is None or len(match.groups()) != 2:
                rewards.append(0.0)
                log_reward("format_reward 0", completion, gt)
            else:
                rewards.append(1.0)
                log_reward("format_reward 1", completion, gt)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"format_reward 0 - exception {e=}", completion, gt)
    return rewards


def execution_reward_func(completions, answer, **kwargs):
    """
    Evaluates completions based on SQL statement execution result.

    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers (gold SQL packed with the
            db_id via BIRD_SEPARATOR)

    Returns:
        list[float]: Reward scores (1.0 when the predicted SQL returns the
        same rows as the gold SQL, else 0.0)
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # NOTE(review): opening <answer> tag reconstructed (see
            # format_reward_func); re.DOTALL added so multi-line SQL between
            # the tags still matches — confirm against the original script.
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("execution_reward 0 - no answer tag found", completion, gt)
                continue

            # Extract the "answer" part from the completion.
            predicted_sql = match.group(1).strip()
            reason = "execution result different"
            # Execute the generated SQL and the gold SQL, compare the results;
            # a hard 30s cap guards against pathological queries.
            try:
                res = func_timeout(30.0, execute_sql, args=(predicted_sql, gt))
            except KeyboardInterrupt:
                sys.exit(0)
            except FunctionTimedOut:
                print("FunctionTimedOut")
                reason = "execution timeout"
                res = 0
            except Exception as e:
                print("Exception", e)
                reason = f"execution exception {e}"
                res = 0

            if res == 1:
                rewards.append(1.0)
                log_reward("execution_reward 1", completion, gt)
            else:
                rewards.append(0.0)
                log_reward(
                    f"execution_reward 0 {reason=}, {predicted_sql=}",
                    completion,
                    gt,
                )
        except Exception as e:
            # If evaluation fails, reward is 0.
            rewards.append(0.0)
            log_reward(f"execution_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_ngrams(tokens: List[str], n: int) -> set:
    """Generates a set of n-grams from a list of tokens."""
    # Ensure there are enough tokens to create at least one n-gram.
    if len(tokens) < n:
        return set()
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def n_gram_jaccard_similarity(candidate_query: str, gold_query: str, n: int) -> float:
    """Calculates the n-gram Jaccard similarity for a single n."""
    # Tokenize the SQL queries. Using .lower() for case-insensitivity.
    candidate_ngrams = get_ngrams(candidate_query.lower().split(), n)
    gold_ngrams = get_ngrams(gold_query.lower().split(), n)
    # Edge cases: both empty -> identical; exactly one empty -> disjoint.
    if not candidate_ngrams and not gold_ngrams:
        return 1.0
    if not candidate_ngrams or not gold_ngrams:
        return 0.0
    intersection = len(candidate_ngrams.intersection(gold_ngrams))
    union = len(candidate_ngrams.union(gold_ngrams))
    return intersection / union


def ensemble_n_gram_reward_func(completions, answer, **kwargs):
    """
    Calculates the averaged ensemble n-gram Jaccard similarity reward.

    Computes the Jaccard similarity of the predicted SQL against the ground
    truth for n=1, 2, and 3 and returns the average score per sample.

    Args:
        completions (list[str]): Generated outputs
        answer (list[str]): Ground truth answers

    Returns:
        list[float]: Reward scores in [0, 1]
    """
    rewards = []
    for completion, gt in zip(completions, answer):
        try:
            # NOTE(review): opening <answer> tag reconstructed; re.DOTALL added
            # for multi-line SQL — confirm against the original script.
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                log_reward("n_gram_reward 0 - no answer tag found", completion, gt)
                continue

            # Extract the "answer" part from the completion.
            predicted_sql = match.group(1).strip()
            # Average the 1/2/3-gram Jaccard scores into one ensemble reward.
            scores = [
                n_gram_jaccard_similarity(predicted_sql, gt, n=k) for k in (1, 2, 3)
            ]
            rewards.append(sum(scores) / 3.0)
        except Exception as e:
            rewards.append(0.0)
            log_reward(f"n_gram_reward 0 - exception {e=}", completion, gt)
    return rewards


def get_checkpoint(training_args: GRPOConfig):
    """Return the most recent checkpoint directory under output_dir, or None."""
    if os.path.isdir(training_args.output_dir):
        return get_last_checkpoint(training_args.output_dir)
    return None


def generate_schema_prompt(db_path, num_rows=None):
    """
    Build a "-- DB Schema:" prompt from the CREATE TABLE DDL of every table
    in the SQLite database at *db_path*.

    :param db_path: path to the .sqlite file
    :param num_rows: when set, also inline that many example rows per table
    :return: schema prompt string
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = cursor.fetchall()
        schemas = {}
        for table in tables:
            # fix: the original compared the whole row tuple to the string, so
            # the sqlite_sequence bookkeeping table was never actually skipped.
            if table[0] == "sqlite_sequence":
                continue
            # fix: parameterized query instead of string-formatted SQL.
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
                (table[0],),
            )
            create_prompt = cursor.fetchone()[0]
            schemas[table[0]] = create_prompt
            if num_rows:
                cur_table = table[0]
                # Backtick-quote table names that collide with SQL keywords.
                if cur_table in ["order", "by", "group"]:
                    cur_table = "`{}`".format(cur_table)
                cursor.execute(
                    "SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)
                )
                column_names = [description[0] for description in cursor.description]
                values = cursor.fetchall()
                # NOTE(review): nice_look_table is not defined in this file;
                # this branch is only reached when num_rows is given (callers
                # pass None) — confirm the helper exists before enabling it.
                rows_prompt = nice_look_table(column_names=column_names, values=values)
                verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                    num_rows, cur_table, num_rows, rows_prompt
                )
                schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
    finally:
        # fix: the original never closed this connection.
        conn.close()

    return "-- DB Schema: " + "\n\n".join(schemas.values())


def generate_comment_prompt(question, knowledge=None):
    """Build the "-- External Knowledge:" / "-- Question:" prompt section."""
    # NOTE(review): when knowledge is None this renders the literal "None";
    # the caller always passes the BIRD "evidence" field.
    knowledge_prompt = "-- External Knowledge: {}".format(knowledge)
    question_prompt = "-- Question: {}".format(question)
    return knowledge_prompt + "\n\n" + question_prompt


def generate_combined_prompts_one(db_path, question, knowledge=None):
    """Combine the DB schema prompt and the knowledge/question prompt."""
    schema_prompt = generate_schema_prompt(db_path, num_rows=None)
    comment_prompt = generate_comment_prompt(question, knowledge)
    return schema_prompt + "\n\n" + comment_prompt


def grpo_function(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    """Build the BIRD text2sql GRPO dataset, then train with GRPOTrainer."""
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    tokenizer = AutoTokenizer.from_pretrained(
        (
            script_args.tokenizer_name_or_path
            if script_args.tokenizer_name_or_path
            else model_args.model_name_or_path
        ),
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = []
    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."

    # fix: json.load(open(...)) leaked the file handle.
    with open(TRAIN_JSON, "r") as f:
        input_json = json.load(f)
    for i, item in tqdm(enumerate(input_json)):
        print(f"processing #{i+1}")
        db_id = item["db_id"]
        question = item["question"]
        external_knowledge = item["evidence"]
        SQL = item["SQL"]
        # fix: DB_ROOT_PATH already ends with "/"; the original inserted a
        # second slash ("...train_databases//<db_id>/...").
        db_path = DB_ROOT_PATH + db_id + "/" + db_id + ".sqlite"
        prompt = generate_combined_prompts_one(
            db_path,
            question,
            knowledge=external_knowledge,
        )
        example = {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
                # Pack gold SQL + db_id into one string; execute_sql() splits
                # them apart again when scoring.
                {"role": "assistant", "content": SQL + BIRD_SEPARATOR + db_id},
            ]
        }
        ds.append(example)

    dataset_dict = {key: [d[key] for d in ds] for key in ds[0]}
    dataset = Dataset.from_dict(dataset_dict)

    def generate_r1_prompt(system_prompt, user_prompt, ground_truth):
        """Wrap one example into an R1-style (prompt, answer) pair.

        Note: system_prompt is accepted for signature compatibility but is
        replaced by the reasoning-oriented system message below.
        """
        # NOTE(review): the literal <think>/<answer> tags below were stripped
        # when this file was extracted and have been reconstructed — confirm
        # against the original script.
        r1_prefix = [
            {
                "role": "system",
                "content": """You are great at reasoning and translating natural language question to SQLite SQL query. Given DB Schema, External Knowledge, and Question, your task is to first generate step-by-step reasoning, then apply the reasoning to generate the SQLite select statement as the accurate translation of the Question. Enclose the step-by-step reasoning within the <think> </think> tags, and the final SQL statement within the <answer> </answer> tags, i.e. <think> reasoning steps </think> <answer> final SQL </answer>.""",
            },
            {"role": "user", "content": user_prompt},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                r1_prefix, tokenize=False, continue_final_message=True
            ),
            "answer": ground_truth,
        }

    # Convert our dataset to the r1 prompt.
    dataset = dataset.map(
        lambda x: generate_r1_prompt(
            x["messages"][0]["content"],
            x["messages"][1]["content"],
            x["messages"][2]["content"],
        ),
        # Remove original columns to avoid conflicts with apply_chat_template.
        remove_columns=dataset.column_names,
    )

    # Split the dataset into train and test.
    train_test_split = dataset.train_test_split(test_size=0.3)
    train_dataset = train_test_split["train"]
    eval_dataset = train_test_split["test"]
    print("len(train_dataset)", len(train_dataset))
    print(train_dataset[0])
    print("len(eval_dataset)", len(eval_dataset))
    print(eval_dataset[0])

    ##########################
    # Instantiate GRPO trainer
    ##########################
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[
            format_reward_func,
            execution_reward_func,
            ensemble_n_gram_reward_func,
        ],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.tokenizer = tokenizer

    ###############
    # Training loop
    ###############
    # Check for last checkpoint; by default resume_from_checkpoint is None.
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")

    # Train the model.
    logger.info(
        f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs***'
    )
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

    # Log and save metrics.
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Training complete ***")

    ##################################
    # Save model and create model card
    ##################################
    logger.info("*** Save model ***")
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    training_args.distributed_state.wait_for_everyone()  # wait for all processes to load

    tokenizer.save_pretrained(training_args.output_dir)
    logger.info(f"Tokenizer saved to {training_args.output_dir}")

    logger.info("*** Training complete! ***")


def main():
    """Parse CLI/config into (ModelConfig, ScriptArguments, GRPOConfig) and train."""
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()

    # Run the main training loop.
    grpo_function(model_args, script_args, training_args)


if __name__ == "__main__":
    main()

# two ways to run this script:
# with-proxy accelerate launch --num_processes 8 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml
# with-proxy nohup accelerate launch --num_processes 4 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml &