dynamically calculating max tokens param; it was unused before

Amir Youssefi, 2 weeks ago
parent commit 6b92409507
1 changed file with 8 additions and 3 deletions

+8 −3   end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py

@@ -24,6 +24,7 @@ from transformers import (
     pipeline,
 )
 
+MAX_NEW_TOKENS=10240  # If API has max tokens (vs max new tokens), we calculate it
 
 def local_llama(prompt, pipe):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
@@ -45,7 +46,7 @@ def local_llama(prompt, pipe):
 
     outputs = pipe(
         raw_prompt,
-        max_new_tokens=10240,
+        max_new_tokens=MAX_NEW_TOKENS,
         do_sample=False,
         temperature=0.0,
         top_k=50,
@@ -187,22 +188,26 @@ def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     try:
         if model.startswith("meta-llama/"):
+            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
+            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS
             llm = ChatTogether(
                 model=model,
                 temperature=0,
+                max_tokens=final_max_tokens,
             )
-            answer = llm.invoke(SYSTEM_PROMPT + "\n\n" + prompt).content
+            answer = llm.invoke(final_prompt).content
         else:
             client = LlamaAPIClient()
             messages = [
                 {"content": SYSTEM_PROMPT, "role": "system"},
                 {"role": "user", "content": prompt},
             ]
-
+            final_max_tokens = len(messages) + MAX_NEW_TOKENS
             response = client.chat.completions.create(
                 model=model,
                 messages=messages,
                 temperature=0,
+                max_completion_tokens=final_max_tokens,
             )
             answer = response.completion_message.content.text
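
A note on the pattern this change introduces, with a minimal standalone sketch. Per the commit's own comment, the idea is: when an API takes a total max tokens parameter (prompt plus completion) rather than a max new tokens parameter, the limit must grow with the prompt, so the commit adds MAX_NEW_TOKENS on top of the prompt length. The estimate_max_tokens helper and the 4-characters-per-token heuristic below are illustrative assumptions, not part of the committed code:

    # Sketch only: mirrors the commit's budgeting idea, not the upstream file.
    MAX_NEW_TOKENS = 10240  # completion budget, same constant the commit defines

    def estimate_max_tokens(prompt: str, chars_per_token: int = 4) -> int:
        """Upper-bound a total (prompt + completion) token budget.

        The commit uses len(final_prompt) directly, i.e. a character count;
        typical tokenizers emit roughly 3-4 characters per token, so a
        character count overestimates prompt tokens and stays a safe upper
        bound. Dividing by an assumed chars_per_token (hypothetical here)
        gives a tighter estimate while keeping the same shape.
        """
        prompt_tokens = len(prompt) // chars_per_token + 1  # rough token estimate
        return prompt_tokens + MAX_NEW_TOKENS

    # Example: the budget now scales with the prompt instead of being a fixed 10240.
    schema_prompt = "Question text plus a long SQLite schema dump..."
    print(estimate_max_tokens(schema_prompt))

One caveat a reviewer might flag: in the LlamaAPIClient branch the commit computes len(messages), which counts the message dicts (here 2) rather than characters or tokens, so final_max_tokens in that branch is effectively MAX_NEW_TOKENS + 2; mirroring the ChatTogether branch with len(SYSTEM_PROMPT) + len(prompt) would make the two paths consistent.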