model_handler.py

import os
import openai
import asyncio
from functools import partial
import json
from token_processor import split_string_by_token_length
from file_handler import get_file_string

# Throttle concurrency to 75% of the model's rate limit to stay under quota
model_rate_limits = 100
max_concurrent_requests = int(model_rate_limits * 0.75)
throttler = asyncio.Semaphore(max_concurrent_requests)

async def send_chat_async(context: dict, request: list) -> str:
    async with throttler:
        try:
            loop = asyncio.get_running_loop()
            # Wrap the synchronous OpenAI API call with partial to pass arguments
            func = partial(
                openai.ChatCompletion.create,
                model=context['model'],
                messages=request,
                temperature=0.0
            )
            # Run the synchronous call in a worker thread so it doesn't block the event loop
            resp = await loop.run_in_executor(None, func)
            # Return the first assistant message, or "" if none is present
            return next(
                (msg['message']['content'] for msg in resp.choices
                 if msg['message']['role'] == 'assistant'),
                ""
            )
        except Exception as e:
            print(f"Error in send_chat_async: {e}")
            return ""

async def request_question(context: dict, input_str: str, num_data: int) -> dict:
    system_prompt = context['question_generator'].format(num_data=num_data, language=context["language"])
    request = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': input_str}]
    reply = await send_chat_async(context, request)
    # A failed request returns ""; guard against a JSONDecodeError on an empty reply
    return json.loads(reply) if reply else {}

async def generate_questions(context: dict):
    doc_string = get_file_string(context)
    batches = split_string_by_token_length(context, doc_string)
    num_questions_per_batch = context["num_data"] // len(batches)
    # Integer division can leave a shortfall of up to len(batches) - 1 questions,
    # so assign the full remainder to the last batch rather than just one extra
    remainder = context["num_data"] - num_questions_per_batch * len(batches)
    tasks = []
    for idx, batch in enumerate(batches):
        num_questions = num_questions_per_batch + (remainder if idx == len(batches) - 1 else 0)
        tasks.append(request_question(context, batch, num_questions))
    results = await asyncio.gather(*tasks)
    return results
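
# Usage sketch, not part of the original module: a minimal, hypothetical
# context dict built from the keys referenced above ('model', 'language',
# 'num_data', 'question_generator'). The model name and prompt template are
# assumptions; any keys consumed by get_file_string or
# split_string_by_token_length are omitted and must be filled in per those
# modules.
if __name__ == "__main__":
    example_context = {
        "model": "gpt-3.5-turbo",  # assumed model name
        "language": "English",
        "num_data": 10,  # total number of questions to generate
        "question_generator": (
            "Generate {num_data} questions in {language} "
            "about the user's text. Respond with valid JSON."
        ),
        # ...plus whatever file/token settings file_handler and
        # token_processor expect
    }
    print(asyncio.run(generate_questions(example_context)))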