@@ -1,8 +1,12 @@
# Enhancing Text-to-SQL with CoT: A Fine-Tuning Approach with Llama
-This folder contains scripts to generate datasets from the BIRD TRAIN set with and, for comparison, without CoT (Chain-of-Thought) and scripts to supervised fine-tune (SFT), as the first step, the Llama 3.1 8B model. We observed a **165% improvement on the fine-tuned model without CoT (accuracy 37.16%) and 209% with CoT (accuracy 43.37%) ** over the original model (accuracy 14.02%).
+CoT stands for Chain of Thought. We will use "CoT" and "reasoning" interchangeably here, although reasoning is generally a broader concept than CoT.
-Note: In this document, we will use "CoT" and "reasoning" interchangeably, although generally, reasoning encompasses a broader concept than CoT.
+This folder contains scripts to:
+
+* generate a dataset from the BIRD TRAIN set (with no CoT info) for supervised fine-tuning (SFT);
+* generate a dataset from the BIRD TRAIN set (with CoT info generated by Llama 3.3 70B) for SFT;
+* SFT the Llama 3.1 8B model on the generated datasets with different fine-tuning combinations: with or without CoT, quantized or non-quantized, and full fine-tuning (FFT) or parameter-efficient fine-tuning (PEFT).
## SFT with the BIRD TRAIN dataset (No Reasoning)
@@ -33,7 +37,24 @@ This will create `train_text2sql_sft_dataset.json` and `test_text2sql_sft_datase
First, you need to log in to Hugging Face (run `huggingface-cli login` and enter your [HF token](https://huggingface.co/settings/tokens)), and you must have been granted access to the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model.
-Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py`.
+Then run one of the commands below (`trl_sft.py` has three command line parameters, `--quantized`, `--peft`, and `--cot`, each taking a true or false value):
+
+```
+python trl_sft.py --quantized false --peft true --cot false
+python trl_sft.py --quantized false --peft false --cot false
+python trl_sft.py --quantized true --peft true --cot false
+```
+
+Note that we don't recommend using the quantized version with FFT (`--peft false`).
+
+After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
+
+```
+llama31-8b-text2sql-fft-nonquantized-nocot
+llama31-8b-text2sql-peft-nonquantized-nocot
+llama31-8b-text2sql-peft-quantized-nocot
+```
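+
+To make the flag combinations and folder names concrete, here is a minimal, hypothetical sketch of how `trl_sft.py` could map the three flags to the training setup and to `output_dir`. Only the flag names, the `SFTConfig` usage, and the folder naming convention come from this README; the argument parsing, the LoRA hyperparameters, and the 4-bit settings are illustrative assumptions, not the actual implementation:
+
+```
+# Hypothetical sketch, not the actual trl_sft.py: how --quantized/--peft/--cot
+# could drive the fine-tuning setup and the output folder name.
+import argparse
+
+import torch
+from peft import LoraConfig
+from transformers import BitsAndBytesConfig
+from trl import SFTConfig
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--quantized", type=str, default="false")  # "true" or "false"
+parser.add_argument("--peft", type=str, default="true")
+parser.add_argument("--cot", type=str, default="false")
+args = parser.parse_args()
+
+quantized = args.quantized.lower() == "true"
+peft = args.peft.lower() == "true"
+cot = args.cot.lower() == "true"
+
+# The output folder name follows the convention of the folders listed above.
+output_dir = "llama31-8b-text2sql-{}-{}-{}".format(
+    "peft" if peft else "fft",
+    "quantized" if quantized else "nonquantized",
+    "cot" if cot else "nocot",
+)
+
+# 4-bit quantization pairs naturally with PEFT (QLoRA-style), which is why
+# --quantized true is not recommended together with FFT (--peft false).
+bnb_config = (
+    BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+    if quantized
+    else None
+)
+
+# A LoRA adapter is trained when --peft true; with --peft false all weights are updated.
+peft_config = (
+    LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
+    if peft
+    else None
+)
+
+training_args = SFTConfig(output_dir=output_dir)
+```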
After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning` you can open `http://localhost:6006` to see the train loss chart etc.:
@@ -42,11 +63,11 @@ After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning` you can o
### Evaluating the fine-tuned model (No Reasoning)
-First, modify `llama_eval.sh` to use the fine-tuned model:
+First, set the `model` value in `llama_eval.sh` to be one of the fine-tuned model folders above, e.g.
```
YOUR_API_KEY='finetuned'
-model='fine_tuning/llama31-8b-text2sql'
+model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-nocot'
```
Then run `sh llama_eval.sh` to evaluate the fine-tuned model. The accuracy on the BIRD DEV dataset is about 37.16%, a 165% improvement over the model before fine-tuning, which has an accuracy of about 14.02% on the same dataset. You can confirm this by comparing the fine-tuned model's accuracy above with the original model's accuracy, obtained by modifying `llama_eval.sh` to use the original model:
@@ -58,16 +79,6 @@ model='meta-llama/Llama-3.1-8B-Instruct'
Then run `sh llama_eval.sh` to evaluate the original model.
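+For reference, the 165% figure is just the relative gain over the original model: (37.16 - 14.02) / 14.02 ≈ 1.65.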
-*Note:* We are using the 4-bit quantized Llama 3.1 8b model to reduce the memory footprint and improve the efficiency (as shown in the code nippet of llama_text2sql.py below), hence the accuracy of the quantized version (14.02%) is quite lower than the accuracy of the original Llama 3.1 8b (35.66%).
-
-```
-    bnb_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_compute_dtype=torch.bfloat16,
-    )
-```
## SFT with the BIRD TRAIN dataset (With Reasoning)
@@ -120,18 +131,34 @@ Let me think through this step by step:\n\n1. First, I need to consider...\n2. T
### SFT (With Reasoning)
-Uncomment the line `# FT_DATASET = "train_text2sql_cot_dataset.json"` in trl_sft.py to use the reasoning dataset for fine-tuning. Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py` - you may want to rename the `output_dir` folder to something else to avoid overwriting the previous fine-tuned model.
+Run one of the commands below:
+
+```
+python trl_sft.py --quantized false --peft true --cot true
+python trl_sft.py --quantized false --peft false --cot true
+python trl_sft.py --quantized true --peft true --cot true
+```
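+
+The `--cot` flag most likely just selects which of the two generated datasets is used for training. Below is a minimal, hypothetical sketch; the dataset filenames are the ones mentioned in this README, while the loading code itself is an assumption rather than the actual `trl_sft.py` logic:
+
+```
+# Hypothetical sketch: picking the training dataset file based on --cot.
+from datasets import load_dataset
+
+cot = True  # value parsed from the --cot command line flag
+dataset_file = (
+    "train_text2sql_cot_dataset.json"   # generated with CoT info (Llama 3.3 70B)
+    if cot
+    else "train_text2sql_sft_dataset.json"  # generated without CoT info
+)
+train_dataset = load_dataset("json", data_files=dataset_file, split="train")
+```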
+
+Again, we don't recommend using the quantized version with FFT.
+
+After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
+
+```
+llama31-8b-text2sql-fft-nonquantized-cot
+llama31-8b-text2sql-peft-nonquantized-cot
+llama31-8b-text2sql-peft-quantized-cot
+```
The train loss chart will look like this:
![train loss](fine_tuning/train_loss_cot.png)
### Evaluating the fine-tuned model (With Reasoning)
-First, modify `llama_eval.sh` to use the fine-tuned model, which should match the `output_dir` in `TrainingArguments` in `trl_sft.py`:
+First, set the `model` value in `llama_eval.sh` to be one of the fine-tuned model folders above, e.g.
```
YOUR_API_KEY='finetuned'
-model='fine_tuning/llama31-8b-text2sql-fine-tuned'
+model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot'
```
Then uncomment the `SYSTEM_PROMPT` line [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py#L31) in `llama_text2sql.py` to use it with the model fine-tuned on the reasoning dataset.
|