generate_question_answers.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import argparse
import asyncio
import json
from config import load_config
from generator_utils import generate_question_batches, parse_qa_to_json
from itertools import chain
import logging
import aiofiles  # Ensure aiofiles is installed for async file operations
from abc import ABC, abstractmethod
from octoai.client import Client
from functools import partial

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling
rate_limit_threshold = 2000
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
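# Throttling note: the semaphore above caps in-flight requests at 75% of the
# assumed provider rate limit (2000 -> 1500 concurrent requests), leaving
# headroom below the limit for retries and other clients on the same API key.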
class ChatService(ABC):
    @abstractmethod
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        pass

# Please implement your own chat service class here.
# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
class OctoAIChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            try:
                # The OctoAI client call is blocking, so run it in the default
                # executor to avoid stalling the event loop.
                event_loop = asyncio.get_running_loop()
                client = Client(api_context['api_key'])
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0
                )
                response = await event_loop.run_in_executor(None, api_chat_call)
                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                assistant_response_json = parse_qa_to_json(assistant_response)
                return assistant_response_json
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}")
                return ""
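
# Illustrative alternative: a ChatService backed by an OpenAI-compatible
# endpoint, following the "implement your own" note above. This sketch is not
# part of the original pipeline; the `openai` package and the optional
# 'base_url' key in api_context are assumptions, so adapt them to whichever
# provider you target.
class OpenAICompatibleChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        # Imported lazily so the module still loads when `openai` is absent.
        from openai import AsyncOpenAI
        async with request_limiter:
            try:
                client = AsyncOpenAI(
                    api_key=api_context['api_key'],
                    base_url=api_context.get('base_url'),  # e.g. a self-hosted server; hypothetical config key
                )
                # This client is natively async, so no run_in_executor is needed.
                response = await client.chat.completions.create(
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0,
                )
                assistant_response = response.choices[0].message.content or ""
                return parse_qa_to_json(assistant_response)
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}")
                return ""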
async def main(context):
    chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        # Each batch yields a list of QA pairs; flatten them into one list.
        flattened_list = list(chain.from_iterable(data))
        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
        # Use an asynchronous file operation for writing to the file
        async with aiofiles.open("data.json", "w") as output_file:
            await output_file.write(json.dumps(flattened_list, indent=4))
        logging.info("Data successfully written to 'data.json'. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}")
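
# Quick sanity check of the generated file (illustrative one-liner, run
# separately from a shell after the script finishes):
#   python -c "import json; qa = json.load(open('data.json')); print(len(qa))"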
def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=10,
        help="Specify the number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
        default="llama-2-70b-chat-fp16",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="config.yaml",
        help="Set the path of the configuration file that holds the system prompt along with the language, dataset path, and number of questions."
    )
    return parser.parse_args()
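
# Example invocation (the flags map to parse_arguments above; values are
# illustrative):
#   python generate_question_answers.py -t 100 -m llama-2-70b-chat-fp16 -c config.yaml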
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    asyncio.run(main(context))