@@ -12,6 +12,7 @@ import aiofiles # Ensure aiofiles is installed for async file operations
 from abc import ABC, abstractmethod
 from octoai.client import Client
 from functools import partial
+from openai import OpenAI

 # Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -28,6 +29,7 @@ class ChatService(ABC):

 # Please implement your own chat service class here.
 # The class should inherit from the ChatService class and implement the execute_chat_request_async method.
+# The following are two example chat service classes that you can use as a reference.
 class OctoAIChatService(ChatService):
     async def execute_chat_request_async(self, api_context: dict, chat_request):
         async with request_limiter:
@@ -43,14 +45,40 @@ class OctoAIChatService(ChatService):
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                 assistant_response_json = parse_qa_to_json(assistant_response)
-
+
                 return assistant_response_json
             except Exception as error:
                 print(f"Error during chat request execution: {error}")
                 return ""
-
+# Use the local vllm OpenAI-compatible server for generating question/answer pairs, to keep the API call syntax consistent.
+# For more detail, please read: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
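+# For illustration, the server can typically be launched with something like:
+#   python -m vllm.entrypoints.openai.api_server --model <model-name> --port <port>
+# and the port must match the one passed via --vllm_endpoint below.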
+class VllmChatService(ChatService):
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        async with request_limiter:
+            try:
+                event_loop = asyncio.get_running_loop()
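+                # The local vllm server does not require a real API key by default, so a placeholder
+                # is passed; api_context['endpoint'] is the local port the vllm server listens on.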
+                client = OpenAI(api_key="EMPTY", base_url="http://localhost:" + api_context['endpoint'] + "/v1")
+                api_chat_call = partial(
+                    client.chat.completions.create,
+                    model=api_context['model'],
+                    messages=chat_request,
+                    temperature=0.0
+                )
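+                # client.chat.completions.create is a blocking call, so it is wrapped in partial and
+                # executed in the default thread-pool executor to avoid stalling the event loop.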
+                response = await event_loop.run_in_executor(None, api_chat_call)
+                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+                assistant_response_json = parse_qa_to_json(assistant_response)
+
+                return assistant_response_json
+            except Exception as error:
+                print(f"Error during chat request execution: {error}")
+                return ""
+
 async def main(context):
-    chat_service = OctoAIChatService()
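+    # Use the local vllm endpoint when a port was supplied via --vllm_endpoint; otherwise fall back to OctoAI.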
+    if context["endpoint"]:
+        logging.info(f"Using local vllm service at port '{context['endpoint']}'.")
+        chat_service = VllmChatService()
+    else:
+        chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question/answer pairs.")
         data = await generate_question_batches(chat_service, context)
@@ -80,8 +108,8 @@ def parse_arguments():
     )
     parser.add_argument(
         "-m", "--model",
-        choices=["llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
-        default="llama-2-70b-chat-fp16",
+        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
+        default="meta-llama-3-70b-instruct",
         help="Select the model to use for generation."
     )
     parser.add_argument(
@@ -89,6 +117,11 @@ def parse_arguments():
         default="config.yaml",
         help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
     )
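+    # The value is used by VllmChatService as the local port, i.e. "http://localhost:<port>/v1".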
+    parser.add_argument(
+        "-v", "--vllm_endpoint",
+        default=None,
+        help="If a port is specified, use the local vllm endpoint for generating question/answer pairs."
+    )
     return parser.parse_args()

 if __name__ == "__main__":
@@ -98,6 +131,6 @@ if __name__ == "__main__":
     context = load_config(args.config_path)
     context["total_questions"] = args.total_questions
     context["model"] = args.model
-
+    context["endpoint"] = args.vllm_endpoint
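+    # Example invocation (script name is a placeholder): python <this_script>.py -v 8000 -m meta-llama-3-70b-instruct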
     logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
-    asyncio.run(main(context))
+    asyncio.run(main(context))