
Add support for a local vLLM endpoint and Llama 3 models

Kai Wu, 1 year ago
commit 230c557730

+ 40 - 7
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -12,6 +12,7 @@ import aiofiles  # Ensure aiofiles is installed for async file operations
 from abc import ABC, abstractmethod
 from octoai.client import Client
 from functools import partial
+from openai import OpenAI
 
 # Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -28,6 +29,7 @@ class ChatService(ABC):
 
 # Please implement your own chat service class here.
 # The class should inherit from the ChatService class and implement the execute_chat_request_async method.
+# The following are two example chat service classes that you can use as a reference.
 class OctoAIChatService(ChatService):
     async def execute_chat_request_async(self, api_context: dict, chat_request):
         async with request_limiter:
@@ -43,14 +45,40 @@ class OctoAIChatService(ChatService):
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                 assistant_response_json = parse_qa_to_json(assistant_response)
-                      
+
                 return assistant_response_json
             except Exception as error:
                 print(f"Error during chat request execution: {error}")
                 return ""
-            
+# Use a local vLLM OpenAI-compatible server to generate question/answer pairs, keeping the API call syntax consistent.
+# For more details, see: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+class VllmChatService(ChatService):
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        async with request_limiter:
+            try:
+                event_loop = asyncio.get_running_loop()
+                client = OpenAI(api_key="EMPTY", base_url=f"http://localhost:{api_context['endpoint']}/v1")
+                api_chat_call = partial(
+                    client.chat.completions.create,
+                    model=api_context['model'],
+                    messages=chat_request,
+                    temperature=0.0
+                )
+                response = await event_loop.run_in_executor(None, api_chat_call)
+                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+                assistant_response_json = parse_qa_to_json(assistant_response)
+
+                return assistant_response_json
+            except Exception as error:
+                print(f"Error during chat request execution: {error}")
+                return ""
+
 async def main(context):
-    chat_service = OctoAIChatService()
+    if context["endpoint"]:
+        logging.info(f" Use local vllm service at port '{context["endpoint"]}'.")
+        chat_service = VllmChatService()
+    else:
+        chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question/answer pairs.")
         data = await generate_question_batches(chat_service, context)
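
The new VllmChatService points the standard OpenAI client at the locally served vLLM endpoint and runs the blocking request in an executor so it fits the existing async flow. Stripped of that machinery, the request it issues looks roughly like the sketch below (a minimal sketch, assuming a vLLM OpenAI-compatible server is already listening on port 8000 and serving a model registered as "meta-llama-3-70b-instruct"; the prompt is made up for illustration).

# Minimal synchronous version of the call that VllmChatService wraps.
from openai import OpenAI

port = "8000"  # hypothetical port; the script reads this from --vllm_endpoint
client = OpenAI(api_key="EMPTY", base_url=f"http://localhost:{port}/v1")
response = client.chat.completions.create(
    model="meta-llama-3-70b-instruct",
    messages=[
        {"role": "system", "content": "You generate question/answer pairs from documents."},
        {"role": "user", "content": "Generate 2 question/answer pairs about vLLM."},
    ],
    temperature=0.0,
)
print(response.choices[0].message.content)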
@@ -80,8 +108,8 @@ def parse_arguments():
     )
     parser.add_argument(
         "-m", "--model",
-        choices=["llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
-        default="llama-2-70b-chat-fp16",
+        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
+        default="meta-llama-3-70b-instruct",
         help="Select the model to use for generation."
     )
     parser.add_argument(
@@ -89,6 +117,11 @@ def parse_arguments():
         default="config.yaml",
         help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
     )
+    parser.add_argument(
+        "-v", "--vllm_endpoint",
+        default=None,
+        help="If a port is specified, then use local vllm endpoint for generating question/answer pairs."
+
     return parser.parse_args()
 
 if __name__ == "__main__":
@@ -98,6 +131,6 @@ if __name__ == "__main__":
     context = load_config(args.config_path)
     context["total_questions"] = args.total_questions
     context["model"] = args.model
-
+    context["endpoint"] = args.vllm_endpoint
     logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
-    asyncio.run(main(context))
+    asyncio.run(main(context))
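
With the new flag, the intended workflow is to start a local server along the lines of "python -m vllm.entrypoints.openai.api_server --model <hf-model> --port 8000" (see the vLLM docs linked in the comment above) and then run the pipeline with "-v 8000" plus the matching "-m" model name. A quick way to confirm the endpoint is reachable, and to see the exact model name it serves, before kicking off a full generation run (a hedged sketch; the port is a placeholder):

# List the models served by the local vLLM OpenAI-compatible endpoint.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")  # hypothetical port
print([model.id for model in client.models.list().data])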

+ 6 - 6
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -75,7 +75,7 @@ def parse_qa_to_json(response_string):
     # Adjusted regex to capture question-answer pairs more flexibly
     # This pattern accounts for optional numbering and different question/answer lead-ins
     pattern = re.compile(
-        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)", 
+        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
         re.DOTALL
     )
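
For reference, the pattern above splits a numbered "Question:/Answer:" transcript into (question, answer) tuples; a quick illustration on a made-up model response:

import re

# Same pattern as in parse_qa_to_json above.
pattern = re.compile(
    r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
    re.DOTALL,
)
sample = (
    "1. Question: What is vLLM?\n"
    "Answer: An open-source engine for serving LLMs.\n"
    "2. Question: Which API does its server expose?\n"
    "Answer: An OpenAI-compatible one."
)
print(pattern.findall(sample))
# [('What is vLLM?', 'An open-source engine for serving LLMs.'),
#  ('Which API does its server expose?', 'An OpenAI-compatible one.')]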
 
@@ -96,9 +96,12 @@ async def prepare_and_send_request(chat_service, api_context: dict, document_con
 
 async def generate_question_batches(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
+    if api_context["model"] in ["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct"]:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", pad_token="</s>", padding_side="right")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
     document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
-    
+
     total_questions = api_context["total_questions"]
     batches_count = len(document_batches)
     base_questions_per_batch = total_questions // batches_count
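
The tokenizer switch matters because split_text_into_chunks presumably sizes chunks with the tokenizer it is handed, and the Llama 3 tokenizer (128K-entry vocabulary) generally produces a different token count for the same text than the Llama 2 tokenizer, so matching the tokenizer family to the selected model keeps chunk sizes aligned with what the model actually sees. A rough comparison (assumes access to the gated meta-llama checkpoints on Hugging Face):

from transformers import AutoTokenizer

text = "vLLM exposes an OpenAI-compatible server for local inference."
llama2_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
llama3_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# The two vocabularies tokenize the same text differently, so chunk boundaries
# would drift if the Llama 2 tokenizer were used to size chunks for a Llama 3 model.
print(len(llama2_tokenizer.encode(text)), len(llama3_tokenizer.encode(text)))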
@@ -116,6 +119,3 @@ async def generate_question_batches(chat_service, api_context: dict):
     question_generation_results = await asyncio.gather(*generation_tasks)
 
     return question_generation_results
-
-
-