
README and create_reasoning_dataset.py and trl_sft.py update

Jeff Tang 4 months ago
parent
commit
46d3245425

+ 21 - 20
end-to-end-use-cases/coding/text2sql/tool/README.md

@@ -6,7 +6,20 @@ This folder contains the scripts for evaluating Llama (original and fine-tuned)
 
 We have significantly simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) for Llama models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), so you can quickly evaluate in 1-2-3 steps how well different Llama models perform on the Text2SQL task.
 
-We have also provided end-to-end scripts for generating datasets and fine-tuning a quantized Llama 3.1 8B model to gain a 165% accuracy improvement over the original model.
+We have also provided end-to-end scripts for generating datasets and fine-tuning a quantized Llama 3.1 8B model to gain a **165% accuracy improvement** over the original model.
+
+## Llama Text2SQL Evaluation Results
+
+Below are the results of the Llama models we have evaluated on the BIRD DEV dataset:
+
+| Model                  | Llama API Accuracy | Together Accuracy |
+|------------------------|--------------------|-------------------|
+| Llama 3.1 8B           | -                  | 35.66%            |
+| Llama 3.3 70B          | 54.11%             | 54.63%            |
+| Llama 3.1 405B         | -                  | 55.80%            |
+| Llama 4 Scout          | 44.39%             | 43.94%            |
+| Llama 4 Maverick       | 44.00%             | 41.46%            |
+
 
 ## Quick Start on Evaluating Llama on Text2SQL
 
@@ -36,17 +49,15 @@ After the script completes, you'll see the accuracy of the Llama model on the BI
 
 *Note:* To compare the evaluated accuracy of your selected Llama model with other results on the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/).
 
-### Evaluation Results
+## Evaluation Process
 
-Below are the results of the Llama models we have evaluated on the BIRD DEV dataset:
+1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries.
 
-| Model                  | Llama API Accuracy | Together Accuracy |
-|------------------------|--------------------|-------------------|
-| Llama 3.1 8b           | -                  | 35.66%            |
-| Llama 3.3 70b          | 54.11%             | 54.63%            |
-| Llama-3.1-405B         | -                  | 55.80%            |
-| Llama 4 Scout          | 44.39%             | 43.94%            |
-| Llama 4 Maverick       | 44.00%             | 41.46%            |
+2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below.
+
+3. **Result Comparison**: The results from executing the generated SQL are compared with the results from the ground truth SQL to determine correctness.
+
+4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging).
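
To make steps 2 and 3 concrete, here is a minimal sketch of the comparison logic, assuming the BIRD databases are SQLite files and that a prediction counts as correct when both queries return the same set of rows; the function names are illustrative and not the exact API of `text2sql_eval.py`:

```python
import sqlite3


def run_query(db_path: str, sql: str) -> frozenset:
    """Execute one SQL statement and return its rows as an order-insensitive set."""
    conn = sqlite3.connect(db_path)
    try:
        return frozenset(conn.execute(sql).fetchall())
    finally:
        conn.close()


def is_prediction_correct(db_path: str, predicted_sql: str, gold_sql: str) -> bool:
    # Steps 2-3: run the generated SQL and the ground-truth SQL against the same
    # database and treat the prediction as correct when the result sets match.
    return run_query(db_path, predicted_sql) == run_query(db_path, gold_sql)
```

Comparing frozensets keeps the check order-insensitive, which matters because two equivalent queries can return the same rows in different orders.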
 
 ## Supported Models
 
@@ -64,16 +75,6 @@ Below are the results of the Llama models we have evaluated on the BIRD DEV data
 - Llama-4-Scout-17B-16E-Instruct-FP8
 - other Llama models hosted on Llama API
 
-## Evaluation Process
-
-1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries.
-
-2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below.
-
-3. **Result Comparison**: The results from executing the generated SQL are compared with the results from the ground truth SQL to determine correctness.
-
-4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging).
-
 ## Preparing Fine-tuning Dataset
 
 ### Using the TRAIN dataset to prepare for supervised fine-tuning

+ 20 - 3
end-to-end-use-cases/coding/text2sql/tool/fine_tuning/create_reasoning_dataset.py

@@ -8,7 +8,7 @@ import sqlite3
 from typing import Dict, List, Tuple
 
 import sqlparse
-from datasets import Dataset
+from datasets import Dataset, load_from_disk
 
 from langchain_together import ChatTogether
 from llama_api_client import LlamaAPIClient
@@ -139,6 +139,16 @@ def generate_schema_prompt(db_path, num_rows=None):
     return schema_prompt
 
 
+def create_conversation(sample):
+    return {
+        "messages": [
+            {"role": "system", "content": sample["messages"][0]["content"]},
+            {"role": "user", "content": sample["messages"][1]["content"]},
+            {"role": "assistant", "content": sample["messages"][2]["content"]},
+        ]
+    }
+
+
 def create_cot_dataset(input_json, db_root_path):
     cot_list = []
     diff = 0
@@ -174,8 +184,6 @@ def create_cot_dataset(input_json, db_root_path):
             .replace("{gold_SQL}", gold_SQL)
         )
         reasoning = llama(prompt_to_generate_reasoning)
-        # print(f"\n======\n{prompt_to_generate_reasoning=}\n\n")
-        # print(f"\n======\n{reasoning=}\n\n")
 
         pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
         matches = pattern.findall(reasoning)
@@ -213,6 +221,15 @@ def create_cot_dataset(input_json, db_root_path):
     hf_dataset = Dataset.from_dict(dataset_dict)
     hf_dataset.save_to_disk(f"text2sql_cot_dataset")
 
+    dataset = load_from_disk("text2sql_cot_dataset")
+    dataset = dataset.map(
+        create_conversation, remove_columns=dataset.features, batched=False
+    )
+    dataset = dataset.train_test_split(test_size=0.3)
+
+    dataset["train"].to_json("train_text2sql_cot_dataset.json", orient="records")
+    dataset["test"].to_json("test_text2sql_cot_dataset.json", orient="records")
+
 
 if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
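
The exported JSON files use the three-message conversation layout produced by `create_conversation`, which is the same format `trl_sft.py` loads. A quick, illustrative sanity check (the file name comes from the code above; the expected roles follow from the mapping function):

```python
from datasets import load_dataset

# Reload the exported training split the same way trl_sft.py does
# and confirm each record carries a system/user/assistant conversation.
train = load_dataset(
    "json", data_files="train_text2sql_cot_dataset.json", split="train"
)
print([m["role"] for m in train[0]["messages"]])  # ['system', 'user', 'assistant']
```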

+ 7 - 5
end-to-end-use-cases/coding/text2sql/tool/fine_tuning/trl_sft.py

@@ -11,9 +11,11 @@ from transformers import (
 )
 from trl import setup_chat_format, SFTTrainer
 
-dataset = load_dataset(
-    "json", data_files="train_text2sql_sft_dataset.json", split="train"
-)
+FT_DATASET = "train_text2sql_sft_dataset.json"
+# uncomment to use the reasoning dataset created by "create_reasoning_dataset.py"
+# FT_DATASET = "train_text2sql_cot_dataset.json"
+
+dataset = load_dataset("json", data_files=SFT_DATASET, split="train")
 
 model_id = "meta-llama/Llama-3.1-8B-Instruct"
 
@@ -48,8 +50,8 @@ peft_config = LoraConfig(
 )
 
 args = TrainingArguments(
-    output_dir="llama31-8b-text2sql-epochs-20",  # directory to save and repository id
-    num_train_epochs=20,  # number of training epochs
+    output_dir="llama31-8b-text2sql-fine-tuned",  # directory to save and repository id
+    num_train_epochs=3,  # number of training epochs
     per_device_train_batch_size=3,  # batch size per device during training
     gradient_accumulation_steps=2,  # number of steps before performing a backward/update pass
     gradient_checkpointing=True,  # use gradient checkpointing to save memory
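
With these settings, each optimizer update aggregates `per_device_train_batch_size` × `gradient_accumulation_steps` = 3 × 2 = 6 examples per device, and the run now covers 3 epochs instead of 20, i.e. 15% of the previous number of passes over the training data.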