
Merge branch 'text2sql' of https://github.com/meta-llama/llama-cookbook into text2sql

Jeff Tang, 3 days ago
parent commit deca42ccd8

+ 31 - 4
end-to-end-use-cases/coding/text2sql/eval/README.md

@@ -1,6 +1,6 @@
 # Llama Text2SQL Evaluation

-We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) to 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), as well as the fine-tuned Llama 3.1 model.
+We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) to 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com), as well as Llama 3.1 8B on Hugging Face and its fine-tuned models.

 ## Evaluation Results

@@ -13,9 +13,9 @@ Below are the results of the Llama models we have evaluated on the BIRD DEV data
 | Llama 4 Scout          | 44.39%             |
 | Llama 4 Maverick       | 44.00%             |

-- Since Llama API does not have Llama 3.1 8b model, we use Hugging Face weights to run locally.
+- Since the Llama API does not offer the Llama 3.1 8B model, we use the Hugging Face weights and vLLM to run it locally.

-## Quick Start
+## Quick Start with Llama Models via Llama API

 First, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning:

@@ -46,12 +46,39 @@ After the script completes, you'll see the accuracy of the Llama model on the BI

 To compare the evaluated accuracy of your selected Llama model with other results on the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/).

+## Evaluation with Llama 3.1 8B on Hugging Face or Fine-tuned Models
+
+We use the vLLM OpenAI-compatible server to run Llama 3.1 8B from Hugging Face (steps below) or its fine-tuned models (steps [here](../fine-tuning/#evaluating-the-fine-tuned-model)) for eval:
+
+1. Uncomment the last two lines in `requirements.txt`, then run `pip install -r requirements.txt`:
+```
+# vllm==0.9.2
+# openai==1.90.0
+```
+
+2. Uncomment the following lines in `llama_eval.sh`:
+```
+YOUR_API_KEY='huggingface'
+model='meta-llama/Llama-3.1-8B-Instruct'
+```
+
+3. Start the vLLM server:
+```
+vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64
+```
+or, if you have multiple GPUs, run something like:
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 4 --max-num-batched-tokens 8192 --max-num-seqs 64
+```
+
+4. Run `sh llama_eval.sh`.
+
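
Before running the full eval, you can sanity-check that the server is reachable with a minimal OpenAI-client call. This is only a sketch: it assumes vLLM's default port 8000 and uses a toy question; the API key value is arbitrary for a local server.

```
# Minimal sanity check against the vLLM OpenAI-compatible server (assumes the default port 8000).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="huggingface")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You translate natural language questions into SQLite queries."},
        {"role": "user", "content": "Table users(id, name). How many users are there? Return only the SQL."},
    ],
    temperature=0,
)
print(response.choices[0].message.content)
```
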
 ## Evaluation Process

 1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries.

 2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below.

-3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L30)) with the results from the ground truth SQL to determine correctness.
+3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L29)) with the results from the ground truth SQL to determine correctness.

 4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging).
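
To make step 3 concrete, here is a minimal sketch of the comparison idea, assuming the SQLite databases used by BIRD; it is not the exact code in `text2sql_eval.py`, which also handles execution errors and the per-difficulty bookkeeping from step 4.

```
import sqlite3

def execution_match(db_path: str, predicted_sql: str, ground_truth_sql: str) -> bool:
    """Return True if both queries yield the same set of rows (order-insensitive)."""
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(predicted_sql)
        predicted_rows = set(cur.fetchall())
        cur.execute(ground_truth_sql)
        ground_truth_rows = set(cur.fetchall())
    except sqlite3.Error:
        # A generated query that fails to execute counts as incorrect.
        return False
    finally:
        conn.close()
    return predicted_rows == ground_truth_rows
```
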

+ 19 - 4
end-to-end-use-cases/coding/text2sql/fine-tuning/README.md

@@ -76,7 +76,9 @@ Let me think through this step by step:\n\n1. First, I need to consider...\n2. T
 Run one of the commands below:

 ```
+python trl_sft.py --quantized false --peft false --cot true
 python trl_sft.py --quantized false --peft true --cot true
+python trl_sft.py --quantized true --peft true --cot true
 ```
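
The three commands correspond to full fine-tuning, LoRA, and quantized LoRA (QLoRA), all on the CoT dataset. As a rough sketch of how such flags typically map onto a trl setup (illustrative hyperparameters, not necessarily the exact configuration in `trl_sft.py`):

```
# Sketch of the usual flag-to-config mapping for a trl SFT script (illustrative values).
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

use_quantized, use_peft = True, True  # e.g. --quantized true --peft true

quant_config = (
    BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    if use_quantized else None
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", quantization_config=quant_config
)
peft_config = (
    LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
    if use_peft else None
)
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="llama31-8b-text2sql-peft-quantized-cot"),
    train_dataset=load_dataset("json", data_files="train_text2sql_cot_dataset.json", split="train"),
    peft_config=peft_config,  # None selects full fine-tuning
)
trainer.train()
```
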
 
 
 After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`:
@@ -84,20 +86,33 @@ After the fine-tuning completes, you'll see the fine-tuned model saved in one of
 ```
 llama31-8b-text2sql-fft-nonquantized-cot
 llama31-8b-text2sql-peft-nonquantized-cot
+llama31-8b-text2sql-peft-quantized-cot
 ```

 The train loss chart should look like this:
 ![](train_loss_cot.png)

-### Evaluating the fine-tuned model (With Reasoning)
+### Evaluating the fine-tuned model

-First, set the `model` value in `llama_eval.sh` to be one of the fine-tuned model folders above, e.g.
+1. Set the `model` value in `llama_eval.sh` to be one of the fine-tuned model folders above, e.g.

 ```
 YOUR_API_KEY='finetuned'
 model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot'
 ```

-Then uncomment the line `SYSTEM_PROMPT` [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py#L31) in `llama_text2sql.py` to use it with the reasoning dataset fine-tuned model.
+2. Uncomment the `SYSTEM_PROMPT` line [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py#L17) in `llama_text2sql.py` to use it with the model fine-tuned on the reasoning dataset.

-Now run `sh llama_eval.sh`, which will take longer because the reasoning is needed to generate the SQL. The accuracy this time is 43.37%, compared with 37.16% without reasoning. This is another 16% improvement over the model with fine-tuning without reasoning.
+3. Start the vLLM server by running:
+```
+vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64
+```
+If you have multiple GPUs, you can speed up the eval by running something like:
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64
+```
+
+4. Run `sh llama_eval.sh`.
+
+**Note:** If your fine-tuned model is PEFT-based, you may need to run `python merge_peft.py` (after modifying its `peft_model_path` and `output_dir`) and then pass the merged model folder path to `vllm serve`.
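
For reference, the usual peft merge pattern looks roughly like this (illustrative paths; not necessarily the exact contents of `merge_peft.py`):

```
# Sketch of the standard PEFT adapter-merge pattern (illustrative paths).
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

peft_model_path = "fine_tuning/llama31-8b-text2sql-peft-nonquantized-cot"   # adapter checkpoint
output_dir = "fine_tuning/llama31-8b-text2sql-peft-nonquantized-cot-merged" # merged weights

# Load the base model with the LoRA adapter applied, then fold the adapter
# weights into the base weights so vLLM can serve a plain checkpoint.
model = AutoPeftModelForCausalLM.from_pretrained(peft_model_path, torch_dtype=torch.bfloat16)
model.merge_and_unload().save_pretrained(output_dir)

# Save the tokenizer alongside the merged weights.
AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct").save_pretrained(output_dir)
```
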

+ 1 - 2
end-to-end-use-cases/coding/text2sql/fine-tuning/trl_sft.py

@@ -6,6 +6,7 @@ import sys
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import SFTConfig, SFTTrainer

 # Parse command line arguments
 parser = argparse.ArgumentParser(
@@ -66,8 +67,6 @@ if use_quantized:
 if use_peft:
     from peft import LoraConfig

-from trl import setup_chat_format, SFTConfig, SFTTrainer
-
 # Dataset configuration based on CoT parameter
 if use_cot:
     FT_DATASET = "train_text2sql_cot_dataset.json"
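
The `trl_sft.py` hunks above show the argument parser being set up; since the fine-tuning commands pass boolean values as strings (e.g. `--quantized false`), the parsing typically follows a string-to-bool pattern along these lines (a sketch; the exact flags and defaults in `trl_sft.py` may differ):

```
# Sketch of string-valued boolean flag parsing (e.g. "--quantized false"); illustrative defaults.
import argparse

def str2bool(value: str) -> bool:
    return value.lower() in ("true", "1", "yes")

parser = argparse.ArgumentParser(description="Fine-tune Llama 3.1 8B for Text2SQL")
parser.add_argument("--quantized", type=str2bool, default=False)
parser.add_argument("--peft", type=str2bool, default=True)
parser.add_argument("--cot", type=str2bool, default=True)
args = parser.parse_args()

use_quantized, use_peft, use_cot = args.quantized, args.peft, args.cot
```
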