generate_question_answers.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import argparse
import asyncio
import json
from config import load_config
from generator_utils import generate_question_batches, parse_qa_to_json
from itertools import chain
import logging
import aiofiles  # Ensure aiofiles is installed for async file operations
from abc import ABC, abstractmethod
from octoai.client import OctoAI
from functools import partial
from openai import OpenAI

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling
rate_limit_threshold = 2000
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
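# Both chat services below acquire request_limiter before issuing a call, so at most
# allowed_concurrent_requests chat completions are in flight at any time.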

# OctoAI uses different names for the Llama models, so this mapping resolves an OctoAI
# model name to the official Hugging Face model name (needed by the vllm service).
MODEL_NAME_MAPPING = {
    "meta-llama-3-70b-instruct": "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama-3-8b-instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf",
    "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
    "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
}

class ChatService(ABC):
    @abstractmethod
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        pass
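
# `chat_request` is an OpenAI-style message list, e.g.
# [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}];
# implementations return the parsed question/answer JSON, or "" on failure.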

# Please implement your own chat service class here.
# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
# The following are two example chat service classes that you can use as a reference.
class OctoAIChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            try:
                event_loop = asyncio.get_running_loop()
                client = OctoAI(api_context['api_key'])
                # The client call is blocking, so run it in the default executor to keep the event loop free.
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0
                )
                response = await event_loop.run_in_executor(None, api_chat_call)
                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                assistant_response_json = parse_qa_to_json(assistant_response)
                return assistant_response_json
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}", exc_info=True)
                return ""

# Use the local vllm OpenAI-compatible server to generate question/answer pairs, which keeps the API call
# syntax consistent with the OctoAI path. See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for details.
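# A typical way to launch such a server (assuming vllm is installed; exact flags may vary by version):
#   python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --port 8000
# The port must match the one passed to this script via --vllm_endpoint.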
class VllmChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            try:
                event_loop = asyncio.get_running_loop()
                model_name = MODEL_NAME_MAPPING[api_context['model']]
                client = OpenAI(api_key=api_context['api_key'], base_url="http://localhost:" + str(api_context['endpoint']) + "/v1")
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=model_name,
                    messages=chat_request,
                    temperature=0.0
                )
                response = await event_loop.run_in_executor(None, api_chat_call)
                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                assistant_response_json = parse_qa_to_json(assistant_response)
                if len(assistant_response_json) == 0:
                    logging.error("No question/answer pairs generated. Please check the input context or model configuration.")
                return assistant_response_json
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}", exc_info=True)
                return ""

async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        flattened_list = list(chain.from_iterable(data))
        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
        # Use asynchronous file operation for writing to the file
        async with aiofiles.open("data.json", "w") as output_file:
            await output_file.write(json.dumps(flattened_list, indent=4))
        logging.info("Data successfully written to 'data.json'. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=100,
        help="Specify the total number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="config.yaml",
        help="Path to the configuration file that holds the system prompt along with the language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vllm endpoint on that port for generating question/answer pairs."
    )
    return parser.parse_args()
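
# Example usage, based on the arguments defined above (adjust values to your setup):
#   OctoAI backend:
#     python generate_question_answers.py -t 200 -m meta-llama-3-70b-instruct -c config.yaml
#   Local vllm server listening on port 8000:
#     python generate_question_answers.py -t 200 -m meta-llama-3-70b-instruct -c config.yaml -v 8000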

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using the local vllm service at port '{args.vllm_endpoint}'.")
    asyncio.run(main(context))