# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import argparse
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from functools import partial
from itertools import chain

import aiofiles  # Ensure aiofiles is installed for async file operations
from octoai.client import Client
from openai import OpenAI

from config import load_config
from generator_utils import generate_question_batches, parse_qa_to_json

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling
rate_limit_threshold = 2000
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
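# The semaphore caps concurrent in-flight requests at 75% of the assumed rate limit,
# leaving some headroom below the provider's threshold.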

class ChatService(ABC):
    @abstractmethod
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        pass

# Please implement your own chat service class here.
# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
# The following are two example chat service classes that you can use as a reference.
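# Both examples wrap a synchronous SDK call with run_in_executor so the blocking HTTP
# request runs in a worker thread instead of stalling the asyncio event loop.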
class OctoAIChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            try:
                event_loop = asyncio.get_running_loop()
                client = Client(api_context['api_key'])
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0
                )
                response = await event_loop.run_in_executor(None, api_chat_call)
                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                assistant_response_json = parse_qa_to_json(assistant_response)
                return assistant_response_json
            except Exception as error:
                logging.error(f"Error during OctoAI chat request execution: {error}")
                return ""

# Use the local vllm OpenAI-compatible server for generating question/answer pairs so the API call syntax stays consistent.
# Please read https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for more detail.
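# The OpenAI client below points at the local vllm server on the configured port; vllm's
# OpenAI-compatible server typically does not require an API key unless one is configured,
# so a placeholder value ("EMPTY") is passed.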
class VllmChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            try:
                event_loop = asyncio.get_running_loop()
                client = OpenAI(api_key="EMPTY", base_url="http://localhost:" + str(api_context['endpoint']) + "/v1")
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0
                )
                response = await event_loop.run_in_executor(None, api_chat_call)
                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                assistant_response_json = parse_qa_to_json(assistant_response)
                return assistant_response_json
            except Exception as error:
                logging.error(f"Error during vllm chat request execution: {error}")
                return ""

async def main(context):
    if context["endpoint"]:
        logging.info(f"Using local vllm service at port '{context['endpoint']}'.")
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        flattened_list = list(chain.from_iterable(data))
        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
        # Use an asynchronous file operation for writing to the file
        async with aiofiles.open("data.json", "w") as output_file:
            await output_file.write(json.dumps(flattened_list, indent=4))
        logging.info("Data successfully written to 'data.json'. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}")

def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=10,
        help="Specify the number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="config.yaml",
        help="Set the configuration file path that has the system prompt along with the language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        help="If a port is specified, use the local vllm endpoint for generating question/answer pairs."
    )
    return parser.parse_args()

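# Example invocation (adjust the script name to match this file):
#   python generate_question_answers.py -t 100 -m meta-llama-3-70b-instruct -c config.yaml -v 8000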
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    asyncio.run(main(context))