
Some cleanup and typo fixes

Jeff Tang, 3 months ago
Parent commit 799dee6813

+ 7 - 6
end-to-end-use-cases/coding/text2sql/eval/README.md

@@ -20,10 +20,11 @@ Below are the results of the Llama models we have evaluated on the BIRD DEV data
 First, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning:
 
 ```
-git clone https://github.com/meta-llama/llama-cookbook
-cd llama-cookbook/end-to-end-use-cases/coding/text2sql
 conda create -n llama-text2sql python=3.10
 conda activate llama-text2sql
+git clone https://github.com/meta-llama/llama-cookbook
+cd llama-cookbook/end-to-end-use-cases/coding/text2sql/eval
+git checkout text2sql # to be removed after the PR merge
 pip install -r requirements.txt
 ```
 
@@ -31,7 +32,7 @@ Then, follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using th
 
 1. Get the DEV dataset:
 ```
-cd data
+cd ../data
 sh download_dev_unzip.sh
 cd ../eval
 ```
@@ -46,7 +47,7 @@ After the script completes, you'll see the accuracy of the Llama model on the BI
 
 To compare the evaluated accuracy of your selected Llama model with other results on the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/).
 
-## Evaluation with Llama Models on Hugging Face or Fine-tuned 
+## Evaluation with Llama Models on Hugging Face or Fine-tuned
 
 We use the vllm OpenAI-compatible server to run Llama 3.1 8B from Hugging Face (steps below) or its fine-tuned models (steps [here](../fine-tuning/#evaluating-the-fine-tuned-model)) for eval:
 
@@ -63,7 +64,7 @@ model='meta-llama/Llama-3.1-8B-Instruct'
 ```
 
 3. Start the vllm server:
-```   
+```
 vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64
 ```
 or if you have multiple GPUs, do something like:
@@ -72,7 +73,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve meta-llama/Llama-3.1-8B-Instruct --tenso
 ```
 
 then run `sh llama_eval.sh`.
-   
+
 ## Evaluation Process
 
 1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries.
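
For context, here is a minimal sketch (an illustration under assumptions, not the repo's actual implementation) of what that SQL generation step does conceptually: send one natural-language question plus its database schema to the vllm OpenAI-compatible endpoint started above and collect the returned SQL. The prompt layout, port, and helper name `generate_sql` are assumptions; the real prompt construction lives in `llama_text2sql.py`.

```
# Illustrative sketch of the "SQL Generation" step: one question plus its
# schema is sent to an OpenAI-compatible endpoint (e.g. the vllm server
# started above) and the generated SQL is collected.
from openai import OpenAI  # assumes the openai client package is installed

# Default vllm serve address; adjust if you changed host/port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="huggingface")

def generate_sql(question: str, schema: str, model: str) -> str:
    # Prompt layout is illustrative only; the real prompt lives in llama_text2sql.py.
    prompt = (
        f"Database schema:\n{schema}\n\n"
        f"Question: {question}\n"
        "Return only the SQL query."
    )
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response.choices[0].message.content.strip()

print(generate_sql(
    "How many schools are in California?",
    "CREATE TABLE schools (id INTEGER, name TEXT, state TEXT);",
    "meta-llama/Llama-3.1-8B-Instruct",
))
```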

+ 5 - 4
end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh

@@ -6,8 +6,8 @@ db_root_path='../data/dev_20240627/dev_databases/'
 ground_truth_path='../data/'
 
 # Llama models on Llama API
-# YOUR_API_KEY='YOUR_LLAMA_API_KEY'
-# model='Llama-3.3-8B-Instruct'
+YOUR_API_KEY='YOUR_LLAMA_API_KEY'
+model='Llama-3.3-8B-Instruct'
 #model='Llama-3.3-70B-Instruct'
 #model='Llama-4-Maverick-17B-128E-Instruct-FP8'
 #model='Llama-4-Scout-17B-16E-Instruct-FP8'
@@ -17,12 +17,13 @@ ground_truth_path='../data/'
 # model='meta-llama/Llama-3.1-8B-Instruct'
 
 # Fine-tuned Llama models locally
-YOUR_API_KEY='finetuned'
-model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3'
+# YOUR_API_KEY='finetuned'
+# model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3'
 
 data_output_path="./output/$model/"
 
 echo "Text2SQL using $model"
+
 python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \
 --model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path}
 

+ 5 - 7
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -311,13 +311,13 @@ def batch_collect_response_from_llama(
             )
         prompts.append(cur_prompt)
 
-    print(f"Generated {len(prompts)} prompts for batch processing")
+    print(f"Generated {len(prompts)} prompts for Llama processing")
 
-    # Process prompts in parallel
     if api_key in [
         "huggingface",
         "finetuned",
-    ]:  # running vllm on multiple GPUs to see best performance
+    ]:
+        # Process prompts in parallel; running vllm on multiple GPUs for best eval performance
         results = local_llama(
             client=client,
             api_key=api_key,
@@ -414,14 +414,13 @@ if __name__ == "__main__":
         os.environ["LLAMA_API_KEY"] = args.api_key
 
         try:
+            # test if the Llama API key is valid
             client = LlamaAPIClient()
-
-            response = client.chat.completions.create(
+            client.chat.completions.create(
                 model=args.model,
                 messages=[{"role": "user", "content": "125*125 is?"}],
                 temperature=0,
             )
-            answer = response.completion_message.content.text
         except Exception as exception:
             print(f"{exception=}")
             exit(1)
@@ -433,7 +432,6 @@ if __name__ == "__main__":
     )
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
-    print(f"Using batch processing with batch_size={args.batch_size}")
     if args.use_knowledge == "True":
         responses = batch_collect_response_from_llama(
             db_path_list=db_path_list,
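
As an aside on the parallel path above (when `api_key` is `huggingface` or `finetuned`), the sketch below shows one way prompts could be fanned out to a local vllm OpenAI-compatible server with a thread pool. It is only an assumption-laden illustration; the actual logic lives in `local_llama()` inside `llama_text2sql.py`, and `MAX_WORKERS`, `send_prompt`, and `run_in_parallel` are hypothetical names.

```
# Illustrative sketch only: fanning prompts out in parallel to a local
# vllm OpenAI-compatible server. Not the repo's local_llama() implementation.
from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI

MAX_WORKERS = 8  # assumption: tune to GPU count / the --max-num-seqs used by vllm serve

client = OpenAI(base_url="http://localhost:8000/v1", api_key="finetuned")

def send_prompt(prompt: str, model: str) -> str:
    # One chat completion per prompt; temperature 0 for deterministic SQL.
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response.choices[0].message.content

def run_in_parallel(prompts: list[str], model: str) -> list[str]:
    # pool.map preserves input order, so results line up with prompts.
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        return list(pool.map(lambda p: send_prompt(p, model), prompts))
```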