run_inference.py

import json
import logging
from typing import Any, Dict, List, Optional, TypedDict, Union

import requests
from tqdm import tqdm
from vllm import SamplingParams

from ..data.data_loader import convert_to_conversations, get_formatter, load_data


class VLLMInferenceRequest(TypedDict):
    """Type definition for the vLLM inference request format."""

    messages: List[List[Dict[str, Any]]]
    sampling_params: Union[SamplingParams, List[SamplingParams]]
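
# Illustrative request shape (example values are assumptions, shown for
# reference only): "messages" holds a batch of conversations, each a list of
# role/content message dicts.
#
#     request: VLLMInferenceRequest = {
#         "messages": [[{"role": "user", "content": "Hello!"}]],
#         "sampling_params": SamplingParams(temperature=0.7, top_p=1.0, max_tokens=100),
#     }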


class VLLMClient:
    """Client for interacting with a vLLM server."""

    def __init__(self, server_url: str = "http://localhost:8000/v1"):
        """
        Initialize the vLLM client.

        Args:
            server_url: URL of the vLLM server.
        """
        self.server_url = server_url
        self.logger = logging.getLogger(__name__)

    def generate(self, request: VLLMInferenceRequest) -> Dict[str, Any]:
        """
        Send a request to the vLLM server and return the response.

        Args:
            request: The inference request.

        Returns:
            The JSON response from the vLLM server.

        Raises:
            requests.exceptions.RequestException: If the request fails.
        """
        # Format the request for the OpenAI-compatible API.
        vllm_request = {
            "messages": request.get("messages", []),
            "temperature": request.get("temperature", 0.7),
            "top_p": request.get("top_p", 1.0),
            "max_tokens": request.get("max_completion_tokens", 100),
        }

        # Forward optional fields only when the caller supplied them.
        if "seed" in request:
            vllm_request["seed"] = request["seed"]
        if "response_format" in request:
            vllm_request["response_format"] = request["response_format"]

        # Send the request to the vLLM server.
        try:
            response = requests.post(
                f"{self.server_url}/chat/completions",
                json=vllm_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            self.logger.error(f"Error sending request to vLLM server: {e}")
            raise
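

if __name__ == "__main__":
    # Minimal usage sketch, assuming a vLLM OpenAI-compatible server is already
    # running at the default http://localhost:8000/v1. Depending on the
    # deployment, a "model" field may also need to be added to the payload
    # built in generate(); the prompt and values below are illustrative only.
    logging.basicConfig(level=logging.INFO)

    client = VLLMClient()
    demo_request = {
        # generate() forwards "messages" verbatim to /chat/completions, so a
        # single conversation is passed here as a flat list of message dicts.
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "temperature": 0.2,
        "max_completion_tokens": 64,
    }
    demo_response = client.generate(demo_request)
    print(demo_response["choices"][0]["message"]["content"])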