doc_processor.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
# Estimated number of tokens budgeted for each generated question/result.
# split_text_into_chunks multiplies this by context["total_questions"] and
# subtracts the product from the model's token limit when sizing text chunks.
AVERAGE_TOKENS_PER_RESULT = 100
  5. def get_token_limit_for_model(model: str) -> int:
  6. """Returns the token limit for a given model."""
  7. if model == "llama-2-13b-chat" or model == "llama-2-70b-chat":
  8. return 4096
  9. else:
  10. return 8192
  11. def calculate_num_tokens_for_message(encoded_text) -> int:
  12. """Calculates the number of tokens used by a message."""
  13. # Added 3 to account for priming with assistant's reply, as per original comment
  14. return len(encoded_text) + 3
  15. def split_text_into_chunks(context: dict, text: str, tokenizer) -> list[str]:
  16. """Splits a long text into substrings based on token length constraints, adjusted for question generation."""
  17. # Adjusted approach to calculate max tokens available for text chunks
  18. encoded_text = tokenizer(text, return_tensors="pt", padding=True)["input_ids"]
  19. encoded_text = encoded_text.squeeze()
  20. model_token_limit = get_token_limit_for_model(context["model"])
  21. tokens_for_questions = calculate_num_tokens_for_message(encoded_text)
  22. estimated_tokens_per_question = AVERAGE_TOKENS_PER_RESULT
  23. estimated_total_question_tokens = estimated_tokens_per_question * context["total_questions"]
  24. # Ensure there's a reasonable minimum chunk size
  25. max_tokens_for_text = max(model_token_limit - tokens_for_questions - estimated_total_question_tokens, model_token_limit // 10)
  26. chunks, current_chunk = [], []
  27. print(f"Splitting text into chunks of {max_tokens_for_text} tokens, encoded_text {len(encoded_text)}", flush=True)
  28. for token in encoded_text:
  29. if len(current_chunk) >= max_tokens_for_text:
  30. chunks.append(tokenizer.decode(current_chunk).strip())
  31. current_chunk = []
  32. else:
  33. current_chunk.append(token)
  34. if current_chunk:
  35. chunks.append(tokenizer.decode(current_chunk).strip())
  36. print(f"Number of chunks in the processed text: {len(chunks)}", flush=True)
  37. return chunks