
adding support for vllm local endpoint and llama3 model

Kai Wu, 1 year ago
parent
commit
230c557730

+ 40 - 7
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -12,6 +12,7 @@ import aiofiles  # Ensure aiofiles is installed for async file operations
 from abc import ABC, abstractmethod
 from octoai.client import Client
 from functools import partial
+from openai import OpenAI
 
 
 # Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -28,6 +29,7 @@ class ChatService(ABC):
 
 
 # Please implement your own chat service class here.
 # The class should inherit from the ChatService class and implement the execute_chat_request_async method.
+# The following are two example chat service classes that you can use as a reference.
 class OctoAIChatService(ChatService):
     async def execute_chat_request_async(self, api_context: dict, chat_request):
         async with request_limiter:
@@ -43,14 +45,40 @@ class OctoAIChatService(ChatService):
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                 assistant_response_json = parse_qa_to_json(assistant_response)
-                      
+
                 return assistant_response_json
             except Exception as error:
                 print(f"Error during chat request execution: {error}")
                 return ""
-            
+# Use the local vllm OpenAI-compatible server to generate question/answer pairs so the API call syntax stays consistent.
+# For more details, see https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html.
+class VllmChatService(ChatService):
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        async with request_limiter:
+            try:
+                event_loop = asyncio.get_running_loop()
+                client = OpenAI(api_key="EMPTY", base_url="http://localhost:" + api_context['endpoint'] + "/v1")
+                api_chat_call = partial(
+                    client.chat.completions.create,
+                    model=api_context['model'],
+                    messages=chat_request,
+                    temperature=0.0
+                )
+                response = await event_loop.run_in_executor(None, api_chat_call)
+                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+                assistant_response_json = parse_qa_to_json(assistant_response)
+
+                return assistant_response_json
+            except Exception as error:
+                print(f"Error during chat request execution: {error}")
+                return ""
+
 async def main(context):
-    chat_service = OctoAIChatService()
+    if context["endpoint"]:
+        logging.info(f"Using the local vllm service at port '{context['endpoint']}'.")
+        chat_service = VllmChatService()
+    else:
+        chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question/answer pairs.")
         data = await generate_question_batches(chat_service, context)
@@ -80,8 +108,8 @@ def parse_arguments():
     )
     parser.add_argument(
         "-m", "--model",
-        choices=["llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
-        default="llama-2-70b-chat-fp16",
+        choices=["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct","llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
+        default="meta-llama-3-70b-instruct",
         help="Select the model to use for generation."
         help="Select the model to use for generation."
     )
     )
     parser.add_argument(
     parser.add_argument(
@@ -89,6 +117,11 @@ def parse_arguments():
         default="config.yaml",
         default="config.yaml",
         help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
         help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
     )
     )
+    parser.add_argument(
+        "-v", "--vllm_endpoint",
+        default=None,
+        help="If a port is specified, then use local vllm endpoint for generating question/answer pairs."
+
     return parser.parse_args()
     return parser.parse_args()
 
 
 if __name__ == "__main__":
@@ -98,6 +131,6 @@ if __name__ == "__main__":
     context = load_config(args.config_path)
     context["total_questions"] = args.total_questions
     context["model"] = args.model
-
+    context["endpoint"] = args.vllm_endpoint
     logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
     logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
-    asyncio.run(main(context))
+    asyncio.run(main(context))
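
A minimal usage sketch (not part of this commit) of the call pattern VllmChatService relies on: an OpenAI client pointed at a local vllm OpenAI-compatible server. It assumes such a server is already listening on the port you would pass via --vllm_endpoint, for example one started with: python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B-Instruct --port 8001 (the exact launch command depends on your vllm version). The port and model name below are illustrative assumptions, not values taken from the diff.

from openai import OpenAI

# Assumed local vllm server on port 8001; the api_key is unused by vllm but required by the client.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8001/v1")

chat_request = [
    {"role": "system", "content": "You generate question/answer pairs from documents."},
    {"role": "user", "content": "Document chunk goes here."},
]

response = client.chat.completions.create(
    model="meta-llama-3-70b-instruct",  # must match the model name the vllm server is serving
    messages=chat_request,
    temperature=0.0,
)

# Same extraction as the service classes above: keep only the assistant message.
assistant_response = next(
    (choice.message.content for choice in response.choices if choice.message.role == 'assistant'),
    "",
)
print(assistant_response)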

+ 6 - 6
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -75,7 +75,7 @@ def parse_qa_to_json(response_string):
     # Adjusted regex to capture question-answer pairs more flexibly
     # This pattern accounts for optional numbering and different question/answer lead-ins
     pattern = re.compile(
-        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)", 
+        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
         re.DOTALL
     )
 
 
@@ -96,9 +96,12 @@ async def prepare_and_send_request(chat_service, api_context: dict, document_con
 
 
 async def generate_question_batches(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
+    if api_context["model"] in ["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct"]:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", pad_token="</s>", padding_side="right")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
     document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
-    
+
     total_questions = api_context["total_questions"]
     batches_count = len(document_batches)
     base_questions_per_batch = total_questions // batches_count
@@ -116,6 +119,3 @@ async def generate_question_batches(chat_service, api_context: dict):
     question_generation_results = await asyncio.gather(*generation_tasks)
 
 
     return question_generation_results
-
-
-
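
A small self-contained check (illustrative, not from the commit) of what the regex in parse_qa_to_json above extracts; the sample response text is made up.

import re

# Same pattern as parse_qa_to_json in generator_utils.py.
pattern = re.compile(
    r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
    re.DOTALL
)

sample_response = (
    "1. Question: What does the --vllm_endpoint flag do?\n"
    "Answer: It points the generator at a local vllm OpenAI-compatible server.\n"
    "2. Question: Which models can be selected with --model?\n"
    "Answer: Llama 2 chat and Llama 3 instruct variants."
)

# findall yields (question, answer) tuples; whitespace is trimmed for readability.
pairs = [
    {"question": q.strip(), "answer": a.strip()}
    for q, a in pattern.findall(sample_response)
]
print(pairs)  # two question/answer dicts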