updated create_reasoning_dataset, llama_text2sql.py and README for finetuned with reasoning eval

Jeff Tang 4 months ago
parent
commit
6d76ea0f7e

+ 32 - 12
end-to-end-use-cases/coding/text2sql/tool/README.md

@@ -6,7 +6,7 @@ This folder contains the scripts for evaluating Llama (original and fine-tuned)
 
 We have updated and significantly simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), as well as the fine-tuned Llama 3.1 model, so you can quickly evaluate in 1-2-3 steps how well different Llama models perform on the Text2SQL task.
 
-We have also provided end-to-end scripts for generating datasets and fine-tuning a quantized Llama 3.1 8B model to gain a **165% accuracy improvement** over the original model.
+We have also provided end-to-end scripts for generating datasets (with and without reasoning steps) and fine-tuning the quantized Llama 3.1 8B model, gaining accuracy improvements of **165% (without reasoning) and 209% (with reasoning)** over the original model.
 
 ## Llama Text2SQL Evaluation Results
 
@@ -20,7 +20,9 @@ Below are the results of the Llama models we have evaluated on the BIRD DEV data
 | Llama 4 Scout          | 44.39%             | 43.94%            |
 | Llama 4 Maverick       | 44.00%             | 41.46%            |
 
-Llama 3.1 8b quantized model: 14.02% (original) -> 37.16% (fine-tuned)
+Llama 3.1 8B quantized model: 14.02% (original)
+- Fine-tuned without the reasoning dataset: 37.16%
+- Fine-tuned with the reasoning dataset: 43.37%
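+
+The improvement percentages are relative gains over the original model; a quick check of the arithmetic:
+
+```
+base, no_cot, cot = 14.02, 37.16, 43.37
+print(f"{(no_cot - base) / base:.0%}")  # -> 165% improvement without reasoning
+print(f"{(cot - base) / base:.0%}")     # -> 209% improvement with reasoning
+```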
 
 ## Quick Start on Evaluating Llama on Text2SQL
 
@@ -46,6 +48,8 @@ sh download_dev_unzip.sh
 
 3. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL.
 
+*Note:* If your API key or model name is incorrect, the script will exit with an authentication or model not supported error.
+
 After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql benchmark. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'`, or about 35.07% with `YOUR_API_KEY` set to your Together API key and `model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'`.
 
 *Note:* To compare your evaluated accuracy of your selected Llama model with other results in the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/).
@@ -101,18 +105,18 @@ This will create `train_text2sql_sft_dataset.json` and `test_text2sql_sft_datase
 {"messages":[{"content":"You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement.","role":"system"},{"content":"-- DB Schema: <DB_SCHEMA>\n\n-- External Knowledge: <KNOWLEDGE_FROM_TRAIN>\n\n-- Question: <TEXT_QUESTION>","role":"user"},{"content":"<GOLD_SQL>","role":"assistant"}]}
 ```
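+
+To sanity-check the generated files, you can load and inspect them with the `datasets` library (a minimal sketch, assuming the file names created by the step above):
+
+```
+from datasets import load_dataset
+
+# Each line of the JSON file is one {"messages": [...]} conversation.
+ds = load_dataset("json", data_files="train_text2sql_sft_dataset.json", split="train")
+print(ds[0]["messages"][0]["content"])  # the system prompt
+print(ds[0]["messages"][2]["content"])  # the gold SQL answer
+```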
 
-### Supervised Fine-tuning
+### Supervised Fine-tuning (No Reasoning)
 
 First, you need to log in to HuggingFace (by running `huggingface-cli login` and entering your [HF token](https://huggingface.co/settings/tokens)), and you must have been granted access to the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model.
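+
+If you prefer to authenticate from Python instead of the CLI, a minimal sketch using `huggingface_hub`:
+
+```
+from huggingface_hub import login
+
+# Prompts for the HF token from https://huggingface.co/settings/tokens
+login()
+```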
 
-Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine_tuning`.
+Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, as specified by `output_dir="llama31-8b-text2sql-fine-tuned"` in the `TrainingArguments` in `trl_sft.py`.
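+
+For reference, the save location is set by `output_dir` in the `TrainingArguments` in `trl_sft.py`; a minimal sketch of that excerpt (the remaining hyperparameters live in the script itself):
+
+```
+from transformers import TrainingArguments
+
+training_args = TrainingArguments(
+    output_dir="llama31-8b-text2sql-fine-tuned",  # folder the fine-tuned model is saved to
+)
+```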
 
 After running `tensorboard --logdir ./llama31-8b-text2sql-fine-tuned` you can open `http://localhost:6006` to see the train loss chart etc.:
 
 ![](fine_tuning/train_loss.png)
 
 
-### Evaluating the fine-tuned model
+### Evaluating the fine-tuned model (No Reasoning)
 
 First, modify `llama_eval.sh` to use the fine-tuned model:
 
@@ -155,7 +159,7 @@ cd fine_tuning
 python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases
 ```
 
-This will create `text2sql_cot_dataset` dataset in the conversation format ready for fine-tuning. Each example in the dataset is generated from the code snippet below:
+This will create a `text2sql_cot_dataset` dataset and a `train_text2sql_cot_dataset.json` file, both in the conversation format ready for fine-tuning. Each example in the dataset is generated from the code snippet below:
 
 ```
 prompt = f"""
@@ -191,10 +195,26 @@ Let me think through this step by step:\n\n1. First, I need to consider...\n2. T
 """
 ```
 
+### Supervised Fine-tuning (With Reasoning)
+
+Uncomment the line `# FT_DATASET = "train_text2sql_cot_dataset.json"` in `trl_sft.py` to use the reasoning dataset for fine-tuning (a sketch of the toggle follows below), then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, as specified by `output_dir="llama31-8b-text2sql-fine-tuned"` in the `TrainingArguments` in `trl_sft.py`. You may want to rename the `output_dir` folder first to avoid overwriting the previously fine-tuned model.
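+
+A sketch of what the dataset toggle in `trl_sft.py` is assumed to look like once uncommented (file names taken from the steps above):
+
+```
+# Default: the no-reasoning SFT dataset
+# FT_DATASET = "train_text2sql_sft_dataset.json"
+# Uncommented to fine-tune on the reasoning (CoT) dataset instead:
+FT_DATASET = "train_text2sql_cot_dataset.json"
+```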
+
+### Evaluating the fine-tuned model (With Reasoning)
+
+First, modify `llama_eval.sh` to use the fine-tuned model; the model path should match the `output_dir` in the `TrainingArguments` in `trl_sft.py`:
+
+```
+YOUR_API_KEY='finetuned'
+model='fine_tuning/llama31-8b-text2sql-fine-tuned'
+```
+
+Then uncomment the `SYSTEM_PROMPT` line [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/tool/llama_text2sql.py#L31) in `llama_text2sql.py`, so the evaluation uses the system prompt the reasoning model was fine-tuned with.
+
+Now run `sh llama_eval.sh`. The run will take longer this time because the model generates its reasoning before the final SQL. The accuracy is 43.37%, compared with 37.16% without reasoning, a further 16.7% relative improvement ((43.37 - 37.16) / 37.16) over the model fine-tuned without reasoning.
+
 ## Next Steps
-1. Fine-tune the model with the reasoning dataset and evaluate its accuracy.
-2. Add a Colab notebook for fine-tuning and evaluation.
-3. Try reinforcement fine-tuning to improve the accuracy further with reasoning.
-4. Use torchtune for full and non-quantized fine-tuning of Llama 3.3 70b and Llama 4 models.
-5. Introduce agent to try to improve the accuracy further.
-6. Expand the tool to support other databases.
+1. Add a Colab notebook for fine-tuning and evaluation.
+2. Try reinforcement fine-tuning to improve the accuracy further with reasoning.
+3. Use torchtune for full and non-quantized fine-tuning of the Llama 3.3 70B and Llama 4 models.
+4. Introduce an agent to try to improve the accuracy further.
+5. Expand the tool to support other databases.

+ 13 - 18
end-to-end-use-cases/coding/text2sql/tool/llama_eval.sh

@@ -2,36 +2,31 @@ eval_path='./data/dev_20240627/dev.json'
 db_root_path='./data/dev_20240627/dev_databases/'
 ground_truth_path='./data/'
 
-# Llama model on Hugging Face Hub
-# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
-# YOUR_API_KEY='huggingface'
-# model='meta-llama/Llama-3.1-8B-Instruct'
-
-# Fine-tuned Llama model locally
-#YOUR_API_KEY='finetuned'
-#model='fine_tuning/llama31-8b-text2sql-epochs-3'
-#model='fine_tuning/llama31-8b-text2sql-epochs-8'
-
-YOUR_API_KEY='xxx'
 # Llama models on Together
+#YOUR_API_KEY='YOUR_TOGETHER_API_KEY'
 #model='meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
 #model='meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo'
-model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
+#model='meta-llama/Llama-3.3-70B-Instruct-Turbo'
 #model='meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8'
 #model='meta-llama/Llama-4-Scout-17B-16E-Instruct'
 
-#YOUR_API_KEY='yyy'
 # Llama models on Llama API
-#model='Llama-3.3-8B-Instruct'
+YOUR_API_KEY='YOUR_LLAMA_API_KEY'
+model='Llama-3.3-8B-Instruct'
 #model='Llama-3.3-70B-Instruct'
 #model='Llama-4-Maverick-17B-128E-Instruct-FP8'
 #model='Llama-4-Scout-17B-16E-Instruct-FP8'
 
-#model="llama31-8b-text-sql-epochs-25"
-#model="llama31-8b-text-sql-epochs-3"
-#model="llama31-8b-text-sql"
+# Llama model on Hugging Face Hub
+# https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
+# YOUR_API_KEY='huggingface'
+# model='meta-llama/Llama-3.1-8B-Instruct'
+
+# Fine-tuned Llama models locally
+#YOUR_API_KEY='finetuned'
+#model='fine_tuning/llama31-8b-text2sql-fine-tuned'
 
-data_output_path="./output/$model/v2/"
+data_output_path="./output/$model/"
 
 echo "Text2SQL using $model"
 python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \

+ 13 - 3
end-to-end-use-cases/coding/text2sql/tool/llama_text2sql.py

@@ -14,6 +14,7 @@ import sqlparse
 import torch
 from datasets import Dataset, load_dataset
 from langchain_together import ChatTogether
+from llama_api_client import LlamaAPIClient
 from peft import AutoPeftModelForCausalLM
 from tqdm import tqdm
 from transformers import (
@@ -26,6 +27,8 @@ from transformers import (
 
 def local_llama(prompt, pipe):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
+    # UNCOMMENT TO USE THE MODEL FINE-TUNED ON THE REASONING DATASET
+    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
 
     messages = [
         {"content": SYSTEM_PROMPT, "role": "system"},
@@ -51,10 +54,17 @@ def local_llama(prompt, pipe):
         pad_token_id=pipe.tokenizer.pad_token_id,
     )
 
-    generated_answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
+    answer = outputs[0]["generated_text"][len(raw_prompt) :].strip()
 
-    print(f"{generated_answer=}")
-    return generated_answer
+    # The model fine-tuned on the reasoning dataset wraps the final SQL in a
+    # ```sql ... ``` block; extract it if present, else fall back to the raw answer.
+    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+    matches = pattern.findall(answer)
+    if matches:
+        result = matches[0]
+    else:
+        result = answer
+
+    print(f"{result=}")
+    return result
 
 
 def new_directory(path):