import json
import logging
from typing import Any, Dict, List, Optional, TypedDict, Union

import requests
from tqdm import tqdm
from vllm import SamplingParams

from ..data.data_loader import convert_to_conversations, get_formatter, load_data


class VLLMInferenceRequest(TypedDict):
    """Type definition for a vLLM inference request.

    Besides the declared keys, `VLLMClient.generate` also reads optional
    sampling settings from the request dict (`temperature`, `top_p`,
    `max_completion_tokens`, `seed`, `response_format`), falling back to
    defaults when they are absent.
    """

    messages: List[List[Dict[str, Any]]]
    sampling_params: Union[SamplingParams, List[SamplingParams]]


class VLLMClient:
    """Client for interacting with a vLLM server over its OpenAI-compatible API."""

    def __init__(self, server_url: str = "http://localhost:8000/v1"):
        """
        Initialize the vLLM client.

        Args:
            server_url: Base URL of the vLLM server's OpenAI-compatible API
                (including the `/v1` prefix).
        """
        self.server_url = server_url
        self.logger = logging.getLogger(__name__)
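
    # Note: this client assumes a vLLM OpenAI-compatible server is already
    # running at `server_url`. One way to start such a server (sketch; the
    # model name is a placeholder, not something defined in this repo):
    #
    #   python -m vllm.entrypoints.openai.api_server --model <model-name> --port 8000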

    def generate(self, request: VLLMInferenceRequest) -> Dict[str, Any]:
        """
        Send a chat-completion request to the vLLM server and return the response.

        Args:
            request: The inference request.

        Returns:
            The parsed JSON response from the vLLM server.

        Raises:
            requests.exceptions.RequestException: If the request fails.
        """
        # Format the request for the OpenAI-compatible API.
        vllm_request = {
            "messages": request.get("messages", []),
            "temperature": request.get("temperature", 0.7),
            "top_p": request.get("top_p", 1.0),
            "max_tokens": request.get("max_completion_tokens", 100),
        }
        if "seed" in request:
            vllm_request["seed"] = request["seed"]
        if "response_format" in request:
            vllm_request["response_format"] = request["response_format"]

        # Send the request to the vLLM server.
        try:
            response = requests.post(
                f"{self.server_url}/chat/completions",
                json=vllm_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            self.logger.error(f"Error sending request to vLLM server: {e}")
            raise
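
    # Example usage of `generate` (sketch; assumes a vLLM OpenAI-compatible
    # server is serving a chat model at http://localhost:8000/v1 and that
    # `messages` carries a single OpenAI-style conversation):
    #
    #   client = VLLMClient("http://localhost:8000/v1")
    #   request = {
    #       "messages": [{"role": "user", "content": "Hello, who are you?"}],
    #       "temperature": 0.2,
    #       "max_completion_tokens": 128,
    #   }
    #   response = client.generate(request)
    #   print(response["choices"][0]["message"]["content"])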