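"""token_processor.py: helpers built on tiktoken for counting message tokens
and splitting long text into chunks that fit a model's context window."""
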
import tiktoken

# Rough number of tokens to reserve in the output budget for each answer.
AVERAGE_TOKENS_PER_RESULT = 100


def get_token_limit_for_model(model: str) -> int:
    """Returns the context-window token limit for a given model."""
    if model == "gpt-3.5-turbo-16k":
        return 16384
    # Fail loudly instead of silently returning None for unknown models.
    raise ValueError(f"No token limit known for model {model!r}.")


def fetch_encoding_for_model(model: str = "gpt-3.5-turbo-16k"):
    """Fetches the tiktoken encoding for the specified model."""
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using 'cl100k_base' encoding as default.")
        return tiktoken.get_encoding("cl100k_base")


def calculate_num_tokens_for_message(message: str, model: str = "gpt-3.5-turbo-16k") -> int:
    """Calculates the number of tokens used by a message."""
    encoding = fetch_encoding_for_model(model)
    # Add 3 tokens to account for the priming of the assistant's reply.
    return len(encoding.encode(message)) + 3


def split_text_into_tokenized_chunks(context: dict, text_to_split: str) -> list[str]:
    """Splits a long string into substrings based on token length constraints."""
    # Budget per chunk: the model's limit minus the prompt template and the
    # space reserved for the expected answers.
    max_tokens_per_chunk = (
        get_token_limit_for_model(context["model"])
        - calculate_num_tokens_for_message(context["question_prompt_template"])
        - AVERAGE_TOKENS_PER_RESULT * context["total_questions"]
    )
    if max_tokens_per_chunk <= 0:
        raise ValueError("Prompt and reserved answer tokens leave no room for text.")

    encoding = fetch_encoding_for_model(context["model"])
    text_tokens = encoding.encode(text_to_split)

    substrings: list[str] = []
    chunk_tokens: list[int] = []
    for token in text_tokens:
        if len(chunk_tokens) + 1 > max_tokens_per_chunk:
            # Current chunk is full: flush it and start a new one with this token.
            substrings.append(encoding.decode(chunk_tokens).strip())
            chunk_tokens = [token]
        else:
            chunk_tokens.append(token)
    if chunk_tokens:
        substrings.append(encoding.decode(chunk_tokens).strip())
    return substrings
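

if __name__ == "__main__":
    # Minimal usage sketch: the context keys below are the ones the functions
    # above read; the sample prompt and question count are made-up values.
    context = {
        "model": "gpt-3.5-turbo-16k",
        "question_prompt_template": "Answer the following questions about the text:",
        "total_questions": 5,
    }
    sample_text = "word " * 50_000  # long enough to force several chunks
    print("prompt tokens:", calculate_num_tokens_for_message(context["question_prompt_template"]))
    print("chunks:", len(split_text_into_tokenized_chunks(context, sample_text)))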