
dynamically calculate the max tokens param; it was unused before

Amir Youssefi 2 weeks ago
parent
revision 6b92409507

end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py  +8 -3

@@ -24,6 +24,7 @@ from transformers import (
     pipeline,
 )
 
+MAX_NEW_TOKENS = 10240  # if the API takes max total tokens (prompt + completion) rather than max new tokens, we derive the value from this
 
 def local_llama(prompt, pipe):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
@@ -45,7 +46,7 @@ def local_llama(prompt, pipe):
 
     outputs = pipe(
         raw_prompt,
-        max_new_tokens=10240,
+        max_new_tokens=MAX_NEW_TOKENS,
         do_sample=False,
         temperature=0.0,
         top_k=50,
@@ -187,22 +188,26 @@ def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     try:
         if model.startswith("meta-llama/"):
+            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
+            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS  # character count is a conservative upper bound on prompt tokens
             llm = ChatTogether(
                 model=model,
                 temperature=0,
+                max_tokens=final_max_tokens,
             )
-            answer = llm.invoke(SYSTEM_PROMPT + "\n\n" + prompt).content
+            answer = llm.invoke(final_prompt).content
         else:
             client = LlamaAPIClient()
             messages = [
                 {"content": SYSTEM_PROMPT, "role": "system"},
                 {"role": "user", "content": prompt},
             ]
-
+            final_max_tokens = sum(len(m["content"]) for m in messages) + MAX_NEW_TOKENS  # sum message character counts; len(messages) would just count the messages
             response = client.chat.completions.create(
                 model=model,
                 messages=messages,
                 temperature=0,
+                max_completion_tokens=final_max_tokens,
             )
             answer = response.completion_message.content.text
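Both branches approximate prompt token counts with character counts, which over-counts (a token is usually several characters) and so errs on the safe side. A minimal sketch of a more precise variant, assuming a Hugging Face tokenizer is available for the serving model; the model name, placeholder prompt strings, and the count_tokens helper are illustrative, not part of this commit:

from transformers import AutoTokenizer

MAX_NEW_TOKENS = 10240

# Assumed model name; substitute the tokenizer that matches the serving
# model (gated Llama checkpoints require Hugging Face access).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

def count_tokens(text: str) -> int:
    # Token count as the model's tokenizer sees the text.
    return len(tokenizer.encode(text))

SYSTEM_PROMPT = "You are a text to SQL query translator. ..."
prompt = "-- SQLite DB schema, external knowledge, and question go here --"

final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
final_max_tokens = count_tokens(final_prompt) + MAX_NEW_TOKENS

With an exact count, final_max_tokens tracks the true prompt size instead of inflating with the character length, which matters when long schemas push the request toward a provider's context limit. For the chat-completions branch, the same idea would mean tokenizing each message's content (plus a small per-message overhead) rather than the concatenated string.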