@@ -24,6 +24,7 @@ from transformers import (
     pipeline,
 )
 
+MAX_NEW_TOKENS = 10240  # if the API caps total tokens (prompt + completion) rather than new tokens, we derive the cap from the prompt length plus this value
 
 def local_llama(prompt, pipe):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
@@ -45,7 +46,7 @@ def local_llama(prompt, pipe):
     outputs = pipe(
         raw_prompt,
-        max_new_tokens=10240,
+        max_new_tokens=MAX_NEW_TOKENS,
         do_sample=False,
         temperature=0.0,
         top_k=50,
@@ -187,22 +188,26 @@ def cloud_llama(api_key, model, prompt, max_tokens, temperature, stop):
     SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement."
     try:
         if model.startswith("meta-llama/"):
+            final_prompt = SYSTEM_PROMPT + "\n\n" + prompt
+            final_max_tokens = len(final_prompt) + MAX_NEW_TOKENS  # character count is a rough upper bound on the prompt's token count
             llm = ChatTogether(
                 model=model,
                 temperature=0,
+                max_tokens=final_max_tokens,
             )
-            answer = llm.invoke(SYSTEM_PROMPT + "\n\n" + prompt).content
+            answer = llm.invoke(final_prompt).content
         else:
             client = LlamaAPIClient()
             messages = [
                 {"content": SYSTEM_PROMPT, "role": "system"},
                 {"role": "user", "content": prompt},
             ]
-
+            final_max_tokens = len(SYSTEM_PROMPT) + len(prompt) + MAX_NEW_TOKENS  # message text length in characters, a rough upper bound on prompt tokens
             response = client.chat.completions.create(
                 model=model,
                 messages=messages,
                 temperature=0,
+                max_completion_tokens=final_max_tokens,
             )
             answer = response.completion_message.content.text
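Note: both branches size the cap from character counts rather than true token counts. A minimal standalone sketch of the same budget calculation, using that character-length heuristic (the helper name total_token_cap is illustrative, not part of the repo):

    MAX_NEW_TOKENS = 10240

    def total_token_cap(prompt: str, max_new_tokens: int = MAX_NEW_TOKENS) -> int:
        # Upper bound for APIs that cap prompt + completion tokens together.
        # len(prompt) counts characters; a text token spans at least one
        # character, so this overestimates the prompt's token count and keeps
        # the full max_new_tokens budget available for the completion.
        return len(prompt) + max_new_tokens

    # A 1,200-character prompt yields a cap of 1200 + 10240 = 11440 tokens.
    print(total_token_cap("x" * 1200))  # 11440

A tokenizer-based count would give a tighter bound, but the character count is a safe overestimate.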