# doc_processor.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import tiktoken

# Assuming result_average_token is a constant, use UPPER_CASE for its name to follow Python conventions
AVERAGE_TOKENS_PER_RESULT = 100


def get_token_limit_for_model(model: str) -> int:
    """Returns the token limit for a given model."""
    if model == "gpt-3.5-turbo-16k":
        return 16384
    # Fall back to a conservative default rather than returning None for unknown models;
    # 4096 is an assumed placeholder and should be adjusted as more models are added.
    return 4096


def fetch_encoding_for_model(model="gpt-3.5-turbo-16k"):
    """Fetches the encoding for the specified model."""
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: Model not found. Using 'cl100k_base' encoding as default.")
        return tiktoken.get_encoding("cl100k_base")


def calculate_num_tokens_for_message(message: str, model="gpt-3.5-turbo-16k") -> int:
    """Calculates the number of tokens used by a message."""
    encoding = fetch_encoding_for_model(model)
    # Add 3 to account for priming with the assistant's reply, as per the original comment
    return len(encoding.encode(message)) + 3

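
# Note: the fixed 3-token overhead mirrors OpenAI's chat-format accounting, in which each
# assistant reply is primed with a three-token prefix; treat it as an approximation here,
# since this function only sees a bare string rather than a full chat message.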

def split_text_into_chunks(context: dict, text: str) -> list[str]:
    """Splits a long text into substrings based on token length constraints, adjusted for question generation."""
    # Calculate the max tokens available for text chunks after budgeting for the
    # question prompt and the expected question outputs
    model_token_limit = get_token_limit_for_model(context["model"])
    tokens_for_questions = calculate_num_tokens_for_message(context["question_prompt_template"])
    estimated_tokens_per_question = AVERAGE_TOKENS_PER_RESULT
    estimated_total_question_tokens = estimated_tokens_per_question * context["total_questions"]
    # Ensure there's a reasonable minimum chunk size
    max_tokens_for_text = max(
        model_token_limit - tokens_for_questions - estimated_total_question_tokens,
        model_token_limit // 10,
    )
    # Fetch the encoding once instead of re-fetching it for every chunk boundary
    encoding = fetch_encoding_for_model(context["model"])
    encoded_text = encoding.encode(text)
    chunks, current_chunk = [], []
    print(f"Splitting text into chunks of {max_tokens_for_text} tokens, encoded_text {len(encoded_text)}", flush=True)
    for token in encoded_text:
        if len(current_chunk) + 1 > max_tokens_for_text:
            chunks.append(encoding.decode(current_chunk).strip())
            current_chunk = [token]
        else:
            current_chunk.append(token)
    if current_chunk:
        chunks.append(encoding.decode(current_chunk).strip())
    return chunks
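

# A minimal usage sketch (not part of the original module): the context dict below only
# assumes the keys that split_text_into_chunks actually reads ("model",
# "question_prompt_template", "total_questions"); the sample values are purely illustrative.
if __name__ == "__main__":
    sample_context = {
        "model": "gpt-3.5-turbo-16k",
        "question_prompt_template": "Generate {total_questions} questions about the text below.",
        "total_questions": 5,
    }
    sample_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 2000
    for i, chunk in enumerate(split_text_into_chunks(sample_context, sample_text)):
        print(f"Chunk {i}: ~{calculate_num_tokens_for_message(chunk)} tokens")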