
updated fine-tuning README

Jeff Tang, 3 months ago · commit b02334a9cc
1 changed file with 45 additions and 18 deletions
      end-to-end-use-cases/coding/text2sql/fine-tuning/README.md


@@ -1,8 +1,12 @@
# Enhancing Text-to-SQL with CoT: A Fine-Tuning Approach with Llama
 
 
-This folder contains scripts to generate datasets from the BIRD TRAIN set with and, for comparison, without CoT (Chain-of-Thought) and scripts to supervised fine-tune (SFT), as the first step, the Llama 3.1 8B model. We observed a **165% improvement on the fine-tuned model without CoT (accuracy 37.16%) and 209% with CoT (accuracy 43.37%) ** over the original model (accuracy 14.02%).
+CoT stands for Chain of Thought; we use "CoT" and "reasoning" interchangeably here, although reasoning generally encompasses a broader concept than CoT.
 
 
-Note: In this document, we will use "CoT" and "reasoning" interchangeably, although generally, reasoning encompasses a broader concept than CoT.
+This folder contains scripts to:
+
+* generate a dataset from the BIRD TRAIN set (with no CoT info) for supervised fine-tuning (SFT);
+* generate a dataset from the BIRD TRAIN set (with CoT info by Llama 3.3 70B) for SFT;
+* SFT the Llama 3.1 8B model on the generated datasets with different fine-tuning combinations: with or without CoT, quantized or not, and full fine-tuning (FFT) or parameter-efficient fine-tuning (PEFT).
 
 
## SFT with the BIRD TRAIN dataset (No Reasoning)
 
 
@@ -33,7 +37,24 @@ This will create `train_text2sql_sft_dataset.json` and `test_text2sql_sft_datase
 
 
First, you need to log in to Hugging Face (by running `huggingface-cli login` and entering your [HF token](https://huggingface.co/settings/tokens)) and have been granted access to the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model.
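
As a quick sanity check that the login and gated access work, here is a minimal sketch using the `huggingface_hub` Python API (this snippet is illustrative and not part of the repo):

```
from huggingface_hub import login, model_info

login()  # interactive prompt; equivalent to running `huggingface-cli login`
# The call below raises an error if you have not been granted access to the gated repo.
model_info("meta-llama/Llama-3.1-8B-Instruct")
print("Access to Llama 3.1 8B Instruct confirmed")
```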
 
 
-Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py`.
+Then run one of the commands below (`trl_sft.py` has three command line parameters, `--quantized`, `--peft`, and `--cot`, each accepting true or false):
+
+```
+python trl_sft.py --quantized false --peft true --cot false
+python trl_sft.py --quantized false --peft false --cot false
+python trl_sft.py --quantized true --peft true --cot false
+```
+
+Note that we don't recommend using the quantized version with FFT (`--peft false`), since 4-bit quantized weights generally can't be updated directly by full fine-tuning; quantization is intended for the PEFT (QLoRA-style) runs.
+
+
+After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
+
+```
+llama31-8b-text2sql-fft-nonquantized-nocot
+llama31-8b-text2sql-peft-nonquantized-nocot
+llama31-8b-text2sql-peft-quantized-nocot
+```
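
For reference, here is a hypothetical sketch of how the three flags could map to these folder names and to the `SFTConfig` used for training; the actual `trl_sft.py` may do this differently, and everything beyond the flag names and folder names shown above is an assumption:

```
# Hypothetical sketch only; the real trl_sft.py may differ.
import argparse
from trl import SFTConfig

def str2bool(v):
    return str(v).lower() in ("true", "1", "yes")

parser = argparse.ArgumentParser()
parser.add_argument("--quantized", type=str2bool, default=False)
parser.add_argument("--peft", type=str2bool, default=False)
parser.add_argument("--cot", type=str2bool, default=False)
args = parser.parse_args()

# e.g. --quantized true --peft true --cot false -> llama31-8b-text2sql-peft-quantized-nocot
output_dir = "llama31-8b-text2sql-{}-{}-{}".format(
    "peft" if args.peft else "fft",
    "quantized" if args.quantized else "nonquantized",
    "cot" if args.cot else "nocot",
)
training_args = SFTConfig(output_dir=output_dir)
```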
 
 
After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning` you can open `http://localhost:6006` to see the train loss chart and other metrics:
 
 
@@ -42,11 +63,11 @@ After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning` you can o
 
 
### Evaluating the fine-tuned model (No Reasoning)
 
 
-First, modify `llama_eval.sh` to use the fine-tuned model:
+First, set the `model` value in `llama_eval.sh` to one of the fine-tuned model folders above, e.g.:
 
 
```
YOUR_API_KEY='finetuned'
-model='fine_tuning/llama31-8b-text2sql'
+model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-nocot'
```
 
 
Then run `sh llama_eval.sh` to evaluate the fine-tuned model. The accuracy on the BIRD DEV dataset is about 37.16%. This is a 165% improvement over the model before fine-tuning, which has an accuracy of about 14.02% on the same dataset ((37.16 - 14.02) / 14.02 ≈ 1.65). You can confirm this yourself by modifying `llama_eval.sh` to use the original model and comparing its accuracy with the fine-tuned model's accuracy above:
@@ -58,16 +79,6 @@ model='meta-llama/Llama-3.1-8B-Instruct'
 
 
Then run `sh llama_eval.sh` to evaluate the original model.
 
 
-*Note:* We are using the 4-bit quantized Llama 3.1 8b model to reduce the memory footprint and improve the efficiency (as shown in the code nippet of llama_text2sql.py below), hence the accuracy of the quantized version (14.02%) is quite lower than the accuracy of the original Llama 3.1 8b (35.66%).
-
-```
-  bnb_config = BitsAndBytesConfig(
-      load_in_4bit=True,
-      bnb_4bit_use_double_quant=True,
-      bnb_4bit_quant_type="nf4",
-      bnb_4bit_compute_dtype=torch.bfloat16,
-  )
-```
 
 
## SFT with the BIRD TRAIN dataset (With Reasoning)
 
 
@@ -120,18 +131,34 @@ Let me think through this step by step:\n\n1. First, I need to consider...\n2. T
 
 
### SFT (With Reasoning)
 
 
-Uncomment the line `# FT_DATASET = "train_text2sql_cot_dataset.json"` in trl_sft.py to use the reasoning dataset for fine-tuning. Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py` - you may want to rename the `output_dir` folder to something else to avoid overwriting the previous fine-tuned model.
+Run one of the commands below:
+
+```
+python trl_sft.py --quantized false --peft true --cot true
+python trl_sft.py --quantized false --peft false --cot true
+python trl_sft.py --quantized true --peft true --cot true
+```
+
+Again, we don't recommend using the quantized version with FFT.
+
+After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
+
+```
+llama31-8b-text2sql-fft-nonquantized-cot
+llama31-8b-text2sql-peft-nonquantized-cot
+llama31-8b-text2sql-peft-quantized-cot
+```
 
 
The train loss chart will look like this:
![](train_loss_cot.png)
 
 
### Evaluating the fine-tuned model (With Reasoning)
 
 
-First, modify `llama_eval.sh` to use the fine-tuned model, which should match the `output_dir` in `TrainingArguments` in `trl_sft.py`:
+First, set the `model` value in `llama_eval.sh` to one of the fine-tuned model folders above, e.g.:
 
 
```
YOUR_API_KEY='finetuned'
-model='fine_tuning/llama31-8b-text2sql-fine-tuned'
+model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot'
```
 
 
Then uncomment the `SYSTEM_PROMPT` line [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py#L31) in `llama_text2sql.py` to use it with the model fine-tuned on the reasoning dataset.
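
One practical note: the `--peft true` runs typically save a LoRA adapter rather than full model weights. If your evaluation setup needs a standalone checkpoint, here is a minimal sketch of merging the adapter with the `peft` library (the `-merged` output path is an assumption, not part of the repo):

```
# Hypothetical helper, not part of the repo: merge a LoRA adapter from a
# `--peft true` run into a standalone model folder for evaluation.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_dir = "llama31-8b-text2sql-peft-nonquantized-cot"  # one of the folders above
merged_dir = adapter_dir + "-merged"                       # assumed output location

model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir, torch_dtype=torch.bfloat16)
model = model.merge_and_unload()  # fold the LoRA weights into the base model
model.save_pretrained(merged_dir)

# Assumes the tokenizer was saved with the adapter; otherwise load it from the base model.
AutoTokenizer.from_pretrained(adapter_dir).save_pretrained(merged_dir)
```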