
added vllm client

Ubuntu 3 months ago
parent
commit
e68e908d93
1 file changed with 61 additions and 17 deletions

+ 61 - 17
src/finetune_pipeline/inference/run_inference.py

@@ -1,27 +1,71 @@
-"""
-Script to run the model inference.
-
-"""
-
+import json
+import logging
 from typing import Any, Dict, List, Optional, TypedDict, Union
 
-from vllm import SamplingParams
-
+import requests
+from tqdm import tqdm
 
-class InferenceRequest(TypedDict, total=False):
-    """Type definition for LLM inference request."""
+from vllm import SamplingParams
 
-    model: str
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_completion_tokens: int
-    seed: int
-    response_format: Optional[Dict[str, Any]]
+from ..data.data_loader import convert_to_conversations, get_formatter, load_data
 
 
 class VLLMInferenceRequest(TypedDict):
     """Type definition for VLLM inference request format."""
 
-    messages: List[List[Message]]
+    messages: List[List[Dict[str, Any]]]
     sampling_params: Union[SamplingParams, List[SamplingParams]]
+
+
+class VLLMClient:
+    """Client for interacting with a vLLM server."""
+
+    def __init__(self, server_url: str = "http://localhost:8000/v1"):
+        """
+        Initialize the vLLM client.
+
+        Args:
+            server_url: URL of the vLLM server
+        """
+        self.server_url = server_url
+        self.logger = logging.getLogger(__name__)
+
+    def generate(self, request: VLLMInferenceRequest) -> Dict[str, Any]:
+        """
+        Send a request to the vLLM server and get the response.
+
+        Args:
+            request: The inference request
+
+        Returns:
+            The response from the vLLM server
+
+        Raises:
+            requests.exceptions.RequestException: If the request fails
+        """
+        # Format the request for the OpenAI-compatible API
+        vllm_request = {
+            "messages": request.get("messages", []),
+            "temperature": request.get("temperature", 0.7),
+            "top_p": request.get("top_p", 1.0),
+            "max_tokens": request.get("max_completion_tokens", 100),
+        }
+
+        if "seed" in request:
+            vllm_request["seed"] = request["seed"]
+
+        if "response_format" in request:
+            vllm_request["response_format"] = request["response_format"]
+
+        # Send the request to the vLLM server
+        try:
+            response = requests.post(
+                f"{self.server_url}/chat/completions",
+                json=vllm_request,
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.RequestException as e:
+            self.logger.error(f"Error sending request to vLLM server: {e}")
+            raise
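A minimal usage sketch of the new client follows (not part of the commit; the import path, prompts, and parameter values are assumptions). It assumes a vLLM OpenAI-compatible server is already running at http://localhost:8000/v1 with a model loaded. Note that generate() reads flat OpenAI-style keys (temperature, top_p, max_completion_tokens, seed, response_format) from the request dict rather than the sampling_params field declared in VLLMInferenceRequest, and that the /v1/chat/completions endpoint expects a single conversation as a flat list of role/content dicts.

import logging

# Assumed import path, mirroring src/finetune_pipeline/inference/run_inference.py
from finetune_pipeline.inference.run_inference import VLLMClient

logging.basicConfig(level=logging.INFO)

# Client pointing at a locally running vLLM OpenAI-compatible server
client = VLLMClient(server_url="http://localhost:8000/v1")

# generate() forwards these flat keys to the server; the sampling_params
# field of VLLMInferenceRequest is not consulted by this method.
request = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what vLLM does in one sentence."},
    ],
    "temperature": 0.2,
    "top_p": 0.95,
    "max_completion_tokens": 128,
    "seed": 42,
    # "response_format": {"type": "json_object"},  # optional; forwarded when present
}

# Depending on the vLLM version, the server may also expect a "model" field,
# which this client does not set.
response = client.generate(request)
print(response["choices"][0]["message"]["content"])

Since generate() re-raises requests.exceptions.RequestException after logging it, callers that start the server and the client together may want to wrap the call in a simple retry loop while the server warms up.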