
code cleanup and refactoring; cloud llama response generation with tqdm progress

Jeff Tang 3 months ago
parent
commit
12a6dfa2ac

+ 92 - 184
end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -13,36 +13,7 @@ MAX_NEW_TOKENS = 10240  # If API has max tokens (vs max new tokens), we calculat
 TIMEOUT = 60  # Timeout in seconds for each API call
 
 
-def local_llama(client, prompt, model):
-    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
-    # UNCOMMENT TO USE THE FINE_TUNED MODEL WITH REASONING DATASET
-    # SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question."
-
-    messages = [
-        {"content": SYSTEM_PROMPT, "role": "system"},
-        {"role": "user", "content": prompt},
-    ]
-    print(f"local_llama: {model=}")
-    chat_response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        timeout=TIMEOUT,
-        temperature=0,
-    )
-    answer = chat_response.choices[0].message.content.strip()
-
-    pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
-    matches = pattern.findall(answer)
-    if not matches:
-        result = answer
-    else:
-        result = matches[0]
-
-    print(f"{result=}")
-    return result
-
-
-def batch_local_llama(client, prompts, model, max_workers=8):
+def local_llama(client, api_key, prompts, model, max_workers=8):
     """
     Process multiple prompts in parallel using the vllm server.
 
@@ -55,10 +26,19 @@ def batch_local_llama(client, prompts, model, max_workers=8):
     Returns:
         List of results in the same order as prompts
     """
+
     SYSTEM_PROMPT = (
-        "You are a text to SQL query translator. Using the SQLite DB Schema "
-        "and the External Knowledge, translate the following text question "
-        "into a SQLite SQL select statement."
+        (
+            "You are a text to SQL query translator. Using the SQLite DB Schema "
+            "and the External Knowledge, translate the following text question "
+            "into a SQLite SQL select statement."
+        )
+        if api_key == "huggingface"
+        else (
+            "You are a text to SQL query translator. Using the SQLite DB Schema "
+            "and the External Knowledge, generate the step-by-step reasoning and "
+            "then the final SQLite SQL select statement from the text question."
+        )
     )
 
     def process_single_prompt(prompt):
@@ -88,7 +68,7 @@ def batch_local_llama(client, prompts, model, max_workers=8):
             return f"error:{e}"
 
     print(
-        f"batch_local_llama: Processing {len(prompts)} prompts with {model=} "
+        f"local_llama: Processing {len(prompts)} prompts with {model=} "
         f"using {max_workers} workers"
     )
     results = []
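
For reference, a minimal sketch of the ordered fan-out that the local_llama docstring describes (results returned in the same order as the prompts, up to max_workers in flight). The executor choice (concurrent.futures.ThreadPoolExecutor) and the helper name run_prompts_in_parallel are illustrative assumptions, not code from this commit:

import concurrent.futures

def run_prompts_in_parallel(process_single_prompt, prompts, max_workers=8):
    # Submit one task per prompt, cap concurrency at max_workers, and write each
    # result back into its original slot so the output order matches the input.
    results = [None] * len(prompts)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_single_prompt, prompt): i
            for i, prompt in enumerate(prompts)
        }
        for future in concurrent.futures.as_completed(future_to_index):
            results[future_to_index[future]] = future.result()
    return results
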
@@ -231,103 +211,61 @@ def generate_combined_prompts_one(db_path, question, knowledge=None):
     return combined_prompts
 
 
-def cloud_llama(client, api_key, model, prompt, max_tokens, temperature, stop):
-
-    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
-    try:
-        messages = [
-            {"content": SYSTEM_PROMPT, "role": "system"},
-            {"role": "user", "content": prompt},
-        ]
-        final_max_tokens = len(messages) + MAX_NEW_TOKENS
-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=0,
-            max_completion_tokens=final_max_tokens,
-        )
-        answer = response.completion_message.content.text
-
-        pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
-        matches = pattern.findall(answer)
-        if matches != []:
-            result = matches[0]
-        else:
-            result = answer
-
-        print(result)
-    except Exception as e:
-        result = "error:{}".format(e)
-        print(f"{result=}")
-    return result
-
-
-def collect_response_from_llama(
-    db_path_list, question_list, api_key, model, knowledge_list=None
-):
+def cloud_llama(client, api_key, model, prompts):
     """
-    :param db_path: str
-    :param question_list: []
-    :return: dict of responses
-    """
-    response_list = []
-
-    if api_key in ["huggingface", "finetuned"]:
-        from openai import OpenAI
+    Process multiple prompts sequentially using the cloud API, showing progress with tqdm.
 
-        openai_api_key = "EMPTY"
-        openai_api_base = "http://localhost:8000/v1"
+    Args:
+        client: LlamaAPIClient
+        api_key: API key
+        model: Model name
+        prompts: List of prompts to process (or a single prompt as string)
 
-        client = OpenAI(
-            api_key=openai_api_key,
-            base_url=openai_api_base,
-        )
-    else:
-        client = LlamaAPIClient()
+    Returns:
+        List of results if prompts is a list, or a single result if prompts is a string
+    """
+    SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
 
-    for i, question in enumerate(question_list):
-        print(
-            "--------------------- processing question #{}---------------------".format(
-                i + 1
-            )
-        )
-        print("the question is: {}".format(question))
+    # Handle the case where a single prompt is passed
+    single_prompt = False
+    if isinstance(prompts, str):
+        prompts = [prompts]
+        single_prompt = True
 
-        if knowledge_list:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question, knowledge=knowledge_list[i]
-            )
-        else:
-            cur_prompt = generate_combined_prompts_one(
-                db_path=db_path_list[i], question=question
-            )
-
-        if api_key in ["huggingface", "finetuned"]:
-            plain_result = local_llama(client=client, prompt=cur_prompt, model=model)
-        else:
+    results = []
 
-            plain_result = cloud_llama(
-                client=client,
-                api_key=api_key,
+    # Process each prompt sequentially with tqdm progress bar
+    for prompt in tqdm(prompts, desc="Processing prompts", unit="prompt"):
+        try:
+            messages = [
+                {"content": SYSTEM_PROMPT, "role": "system"},
+                {"role": "user", "content": prompt},
+            ]
+            final_max_tokens = len(messages) + MAX_NEW_TOKENS
+            response = client.chat.completions.create(
                 model=model,
-                prompt=cur_prompt,
-                max_tokens=10240,
+                messages=messages,
                 temperature=0,
-                stop=["--", "\n\n", ";", "#"],
+                max_completion_tokens=final_max_tokens,
             )
-        if isinstance(plain_result, str):
-            sql = plain_result
-        else:
-            sql = "SELECT" + plain_result["choices"][0]["text"]
+            answer = response.completion_message.content.text
 
-        # responses_dict[i] = sql
-        db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0]
-        sql = (
-            sql + "\t----- bird -----\t" + db_id
-        )  # to avoid unpredicted \t appearing in codex results
-        response_list.append(sql)
+            pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL)
+            matches = pattern.findall(answer)
+            if matches != []:
+                result = matches[0]
+            else:
+                result = answer
+        except Exception as e:
+            result = "error:{}".format(e)
+            print(f"{result=}")
 
-    return response_list
+        results.append(result)
+
+    # Return a single result if input was a single prompt
+    if single_prompt:
+        return results[0]
+    return results
 
 
 def batch_collect_response_from_llama(
@@ -376,25 +314,24 @@ def batch_collect_response_from_llama(
     print(f"Generated {len(prompts)} prompts for batch processing")
 
     # Process prompts in parallel
-    if api_key in ["huggingface", "finetuned"]:
-        results = batch_local_llama(
-            client=client, prompts=prompts, model=model, max_workers=batch_size
+    if api_key in [
+        "huggingface",
+        "finetuned",
+    ]:  # running vllm on multiple GPUs to see best performance
+        results = local_llama(
+            client=client,
+            api_key=api_key,
+            prompts=prompts,
+            model=model,
+            max_workers=batch_size,
         )
     else:
-        # For cloud API, we could implement a batch version of cloud_llama if needed
-        # For now, just process sequentially
-        results = []
-        for prompt in prompts:
-            plain_result = cloud_llama(
-                client=client,
-                api_key=api_key,
-                model=model,
-                prompt=prompt,
-                max_tokens=10240,
-                temperature=0,
-                stop=["--", "\n\n", ";", "#"],
-            )
-            results.append(plain_result)
+        results = cloud_llama(
+            client=client,
+            api_key=api_key,
+            model=model,
+            prompts=prompts,
+        )
 
     # Format results
     response_list = []
@@ -471,9 +408,6 @@ if __name__ == "__main__":
         default=8,
         help="Number of parallel requests for batch processing",
     )
-    args_parser.add_argument(
-        "--use_batch", type=str, default="True", help="Whether to use batch processing"
-    )
     args = args_parser.parse_args()
 
     if args.api_key not in ["huggingface", "finetuned"]:
@@ -488,62 +422,36 @@ if __name__ == "__main__":
                 temperature=0,
             )
             answer = response.completion_message.content.text
-
-            print(f"{answer=}")
         except Exception as exception:
             print(f"{exception=}")
             exit(1)
 
     eval_data = json.load(open(args.eval_path, "r"))
-    # '''for debug'''
-    # eval_data = eval_data[:3]
-    # '''for debug'''
 
     question_list, db_path_list, knowledge_list = decouple_question_schema(
         datasets=eval_data, db_root_path=args.db_root_path
     )
     assert len(question_list) == len(db_path_list) == len(knowledge_list)
 
-    use_batch = args.use_batch.lower() == "true"
-
-    if use_batch:
-        print(f"Using batch processing with batch_size={args.batch_size}")
-        if args.use_knowledge == "True":
-            responses = batch_collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=knowledge_list,
-                batch_size=args.batch_size,
-            )
-        else:
-            responses = batch_collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=None,
-                batch_size=args.batch_size,
-            )
+    print(f"Using batch processing with batch_size={args.batch_size}")
+    if args.use_knowledge == "True":
+        responses = batch_collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=knowledge_list,
+            batch_size=args.batch_size,
+        )
     else:
-        print("Using sequential processing")
-        if args.use_knowledge == "True":
-            responses = collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=knowledge_list,
-            )
-        else:
-            responses = collect_response_from_llama(
-                db_path_list=db_path_list,
-                question_list=question_list,
-                api_key=args.api_key,
-                model=args.model,
-                knowledge_list=None,
-            )
+        responses = batch_collect_response_from_llama(
+            db_path_list=db_path_list,
+            question_list=question_list,
+            api_key=args.api_key,
+            model=args.model,
+            knowledge_list=None,
+            batch_size=args.batch_size,
+        )
 
     output_name = args.data_output_path + "predict_" + args.mode + ".json"
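
A rough usage sketch of the two refactored entry points, based only on the signatures and docstrings above. The model names and API key placeholders are illustrative; the "EMPTY" key and the http://localhost:8000/v1 base URL come from the removed collect_response_from_llama path, and LlamaAPIClient is the client the script already uses:

# Assumes the module is importable from the eval directory.
from llama_text2sql import cloud_llama, local_llama

prompts = ["<combined schema + question prompt>", "<another prompt>"]

# Cloud path: one API call per prompt, processed sequentially with a tqdm bar.
from llama_api_client import LlamaAPIClient
sqls = cloud_llama(
    client=LlamaAPIClient(),
    api_key="LLAMA_API_KEY",          # anything other than "huggingface"/"finetuned"
    model="Llama-3.3-70B-Instruct",   # illustrative model name
    prompts=prompts,                  # a single string would return a single string
)

# Local path: parallel requests against a vllm OpenAI-compatible server.
from openai import OpenAI
local_client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
sqls = local_llama(
    client=local_client,
    api_key="finetuned",              # any value except "huggingface" selects the reasoning prompt
    prompts=prompts,
    model="path/to/fine-tuned-model",
    max_workers=8,
)

Either path returns plain SQL strings (or "error:..." strings on failure) in the same order as the input prompts.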
 

+ 1 - 0
end-to-end-use-cases/coding/text2sql/eval/requirements.txt

@@ -1,5 +1,6 @@
 llama_api_client==0.1.2
 func_timeout==4.3.5
+tqdm==4.67.1
 
 # uncomment to run vllm for eval with Llama 3.1 8B on HF and its fine-tuned models
 # vllm==0.9.2