@@ -14,96 +14,17 @@ The eval results of SFT Llama 3.1 8B with different options (epochs is 3, with a

| Fine-tuning Combination | Accuracy |
|-----------------------------|-------------------------------|
-| Non-Quantized, CoT, PEFT | 43.35% |
-| Quantized, CoT, PEFT | 42.89% |
-| Non-Quantized, CoT, FFT | 42.44% (43.87% for 10 epochs) |
-| Non-Quantized, No CoT, PEFT | 39.31% |
-| Quantized, No CoT, PEFT | 39.31% |
-| Non-Quantized, No CoT, FFT | 36.31% (38.27% for 10 epochs) |
-| Quantized, CoT, FFT | N/A |
-| Quantized, No CoT, FFT | N/A |
+| baseline | 39.47% |
+| CoT, PEFT | 43.35% |
+| CoT, FFT | 42.44% (3 epochs) |
+| CoT, FFT | 43.87% (10 epochs) |

-The table above shows that:

-1. The CoT FFT/PEFT model (with or without quantization) outperforms the no CoT FFT/PEFT model (with or without quantization) by 3.5% to 6.1%.
+Using quantization + PEFT on the CoT dataset only dropped the accuracy from 43.35% to 42.89%.

-2. The non-quantized PEFT model (CoT or not) is slightly better than the non-quantized FFT model.
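For reference, the quantization + PEFT combination mentioned above means loading the base model in 4-bit and training only LoRA adapters on top of the frozen weights. Below is a minimal sketch of such a setup with `transformers`, `bitsandbytes`, and `peft`; the rank, alpha, and target modules are illustrative assumptions, not necessarily the values used in `trl_sft.py`.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit quantization of the frozen base weights (the "Quantized" option).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)

# LoRA adapters on the attention projections (the "PEFT" option);
# r, lora_alpha, and target_modules are illustrative, not the repo's settings.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```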

+## Creating dataset

-## SFT with the BIRD TRAIN dataset (No Reasoning)
-
-We'll first use the BIRD TRAIN dataset to prepare for supervised fine-tuning with no reasoning info in the dataset.
-
-### Using the TRAIN dataset to prepare for supervised fine-tuning
-
-1. Get the TRAIN dataset:
-```
-cd data
-sh download_train_unzip.sh
-```
-
-2. Create the dataset:
-
-```
-cd ../fine_tuning
-python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases
-```
-
-This will create `train_text2sql_sft_dataset.json` and `test_text2sql_sft_dataset.json` using the TRAIN set. Each line in the json files is in the conversation format ready for fine-tuning:
-
-```
-{"messages":[{"content":"You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement.","role":"system"},{"content":"-- DB Schema: <DB_SCHEMA>\n\n-- External Knowledge: <KNOWLEDGE_FROM_TRAIN>\n\n-- Question: <TEXT_QUESTION>","role":"user"},{"content":"<GOLD_SQL>","role":"assistant"}]}
-```
-
-### SFT (No Reasoning)
-
-First, log in to Hugging Face (run `huggingface-cli login` and enter your [HF token](https://huggingface.co/settings/tokens)), and make sure you have been granted access to the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model.
-
-Then run one of the commands below (`trl_sft.py` has three command line parameters: `--quantized`, `--peft`, and `--cot`, all with true or false values):
-
-```
-python trl_sft.py --quantized false --peft true --cot false
-python trl_sft.py --quantized false --peft false --cot false
-python trl_sft.py --quantized true --peft true --cot false
-```
-
-Note that we don't recommend using the quantized version with FFT (`--peft false`).
-
-
-After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
-
-```
-llama31-8b-text2sql-fft-nonquantized-nocot
-lama31-8b-text2sql-peft-nonquantized-nocot
-llama31-8b-text2sql-peft-quantized-nocot
-```
-
-After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning`, you can open `http://localhost:6006` to see the train loss chart like this:
-
-
-
-
-### Evaluating the fine-tuned model (No Reasoning)
-
-First, set the `model` value in `llama_eval.sh` to one of the fine-tuned model folders above, e.g.
-
-```
-YOUR_API_KEY='finetuned'
-model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-nocot'
-```
-
-Then run `sh llama_eval.sh` to evaluate the fine-tuned model. The accuracy on the BIRD DEV dataset is about 37.16%, a 165% improvement over the original model, which scores about 14.02% on the same dataset. You can confirm this by modifying `llama_eval.sh` to use the original model:
-
-```
-YOUR_API_KEY='huggingface'
-model='meta-llama/Llama-3.1-8B-Instruct'
-```
-
-Then run `sh llama_eval.sh` to evaluate the original model.
-
-
-## SFT with the BIRD TRAIN dataset (With Reasoning)
-
-Next we'll use the BIRD TRAIN dataset to prepare for supervised fine-tuning with reasoning info in the dataset. The goal is to see if we can improve the accuracy of the fine-tuned model by adding the reasoning info in the dataset.
+We use the BIRD TRAIN dataset to prepare for supervised fine-tuning with reasoning info included. The goal is to see if adding the reasoning info to the dataset improves the accuracy of the fine-tuned model.

### Creating a reasoning dataset from the TRAIN dataset

@@ -150,24 +71,19 @@ Let me think through this step by step:\n\n1. First, I need to consider...\n2. T
"""
```
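However the reasoning is generated, each training example ultimately has to be serialized in the same conversation format used for the non-reasoning dataset, with the reasoning placed ahead of the final SQL in the assistant turn. A minimal sketch of how such a record could be assembled; the helper name and the exact assistant layout are assumptions for illustration, not the repo's implementation:

```python
import json

def build_cot_example(db_schema: str, knowledge: str, question: str,
                      reasoning: str, gold_sql: str) -> str:
    """Serialize one CoT training example as a single JSON line.

    Assumes the same messages format as the non-reasoning dataset, with the
    generated reasoning prepended to the gold SQL in the assistant turn.
    """
    messages = [
        {"role": "system",
         "content": ("You are a text to SQL query translator. Using the SQLite DB "
                     "Schema and the External Knowledge, translate the following "
                     "text question into a SQLite SQL select statement.")},
        {"role": "user",
         "content": f"-- DB Schema: {db_schema}\n\n-- External Knowledge: "
                    f"{knowledge}\n\n-- Question: {question}"},
        {"role": "assistant",
         "content": f"{reasoning}\n\n{gold_sql}"},
    ]
    return json.dumps({"messages": messages})
```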

-### SFT (With Reasoning)
+### Running fine-tuning

Run one of the commands below:

```
python trl_sft.py --quantized false --peft true --cot true
-python trl_sft.py --quantized false --peft false --cot true
-python trl_sft.py --quantized true --peft true --cot true
```

-Again we don't recommend using the quantized version with FFT.
-
After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:

```
llama31-8b-text2sql-fft-nonquantized-cot
-lama31-8b-text2sql-peft-nonquantized-cot
-llama31-8b-text2sql-peft-quantized-cot
+llama31-8b-text2sql-peft-nonquantized-cot
```
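The folder name is whatever `output_dir` is set to in the script's `SFTConfig`. As a rough sketch of how this is typically wired up with `trl` (the dataset path and hyperparameter values here are illustrative assumptions, not the actual settings in `trl_sft.py`):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder path: substitute the reasoning dataset produced in the previous step.
train_dataset = load_dataset("json", data_files="train_text2sql_cot_dataset.json",
                             split="train")

# Illustrative hyperparameters; trl_sft.py derives its own from the
# --quantized / --peft / --cot flags.
config = SFTConfig(
    output_dir="llama31-8b-text2sql-peft-nonquantized-cot",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    logging_steps=10,
    report_to="tensorboard",  # writes the train loss curve shown below
)

trainer = SFTTrainer(
    model="meta-llama/Llama-3.1-8B-Instruct",
    train_dataset=train_dataset,
    args=config,
)
trainer.train()
```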

The train loss chart should look like this: