@@ -1,27 +1,71 @@
-"""
-Script to run the model inference.
-
-"""
-
+import json
+import logging
 from typing import Any, Dict, List, Optional, TypedDict, Union
-from vllm import SamplingParams
-
+import requests
+from tqdm import tqdm
-class InferenceRequest(TypedDict, total=False):
-    """Type definition for LLM inference request."""
+from vllm import SamplingParams
-    model: str
-    messages: List[Message]
-    temperature: float
-    top_p: float
-    max_completion_tokens: int
-    seed: int
-    response_format: Optional[Dict[str, Any]]
+from ..data.data_loader import convert_to_conversations, get_formatter, load_data
 class VLLMInferenceRequest(TypedDict):
     """Type definition for VLLM inference request format."""
-    messages: List[List[Message]]
+    messages: List[List[Dict[str, Any]]]
     sampling_params: Union[SamplingParams, List[SamplingParams]]
+
+
+class VLLMClient:
+    """Client for interacting with a vLLM server."""
+
+    def __init__(self, server_url: str = "http://localhost:8000/v1"):
+        """
+        Initialize the vLLM client.
+
+        Args:
+            server_url: URL of the vLLM server
+        """
+        self.server_url = server_url
+        self.logger = logging.getLogger(__name__)
+
+    def generate(self, request: VLLMInferenceRequest) -> Dict[str, Any]:
+        """
+        Send a request to the vLLM server and get the response.
+
+        Args:
+            request: The inference request
+
+        Returns:
+            The response from the vLLM server
+
+        Raises:
+            requests.exceptions.RequestException: If the request fails
+        """
+        # Format the request for the OpenAI-compatible API
+        vllm_request = {
+            "messages": request.get("messages", []),
+            "temperature": request.get("temperature", 0.7),
+            "top_p": request.get("top_p", 1.0),
+            "max_tokens": request.get("max_completion_tokens", 100),
+        }
+
+        if "seed" in request:
+            vllm_request["seed"] = request["seed"]
+
+        if "response_format" in request:
+            vllm_request["response_format"] = request["response_format"]
+
+        # Send the request to the vLLM server
+        try:
+            response = requests.post(
+                f"{self.server_url}/chat/completions",
+                json=vllm_request,
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.RequestException as e:
+            self.logger.error(f"Error sending request to vLLM server: {e}")
+            raise
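
Usage sketch (illustrative, not part of the patch): a minimal example of calling the VLLMClient added above. The module name "inference" is an assumption, since the hunk does not show the file path; the request is passed as a plain dict because generate() only reads keys via request.get(), and the response indexing assumes an OpenAI-style chat completion payload.

# Hypothetical usage of the VLLMClient introduced in this diff.
# "inference" as the module name is an assumption; adjust to the real file path.
from inference import VLLMClient

client = VLLMClient(server_url="http://localhost:8000/v1")

# generate() forwards these keys to the OpenAI-compatible /chat/completions
# endpoint. Depending on the server configuration, that endpoint may also
# require a "model" field, which generate() as written does not send.
request = {
    "messages": [{"role": "user", "content": "Summarize vLLM in one sentence."}],
    "temperature": 0.2,
    "top_p": 0.9,
    "max_completion_tokens": 64,
    "seed": 42,
}

response = client.generate(request)
# Assumes the server returns an OpenAI-style chat completion payload.
print(response["choices"][0]["message"]["content"])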