瀏覽代碼

added vllm client

Ubuntu 1 月之前
父節點
當前提交
99943dd7b1
共有 1 個文件被更改,包括 49 次插入2 次删除
  1. 49 2
      src/finetune_pipeline/inference/run_inference.py

+ 49 - 2
src/finetune_pipeline/inference/run_inference.py

@@ -14,7 +14,7 @@ class VLLMInferenceRequest(TypedDict):
     """Type definition for VLLM inference request format."""
 
     messages: List[List[Dict[str, Any]]]
-    sampling_params: Union[SamplingParams, List[SamplingParams]]
+    sampling_params: Union[SamplingParams, List[SamplingParams]]    
 
 
 class VLLMClient:
@@ -68,4 +68,51 @@ class VLLMClient:
             return response.json()
         except requests.exceptions.RequestException as e:
             self.logger.error(f"Error sending request to vLLM server: {e}")
-            raise
+            raise
+
def run_inference_on_eval_data(
    eval_data_path: str,
    server_url: str = "http://localhost:8000/v1",
    is_local: bool = False,
    temperature: float = 0.0,
    top_p: float = 1.0,
    max_tokens: int = 100,
    seed: Optional[int] = None,
    dataset_kwargs: Optional[Dict[str, Any]] = None,
    column_mapping: Optional[Dict[str, str]] = None,
) -> List[Dict[str, Any]]:
    """
    Run inference on evaluation data using a vLLM server.

    NOTE(review): this function is currently a stub. It loads and formats the
    data but never sends a request to the server, so it implicitly returns
    ``None`` rather than the documented list of responses, and the sampling
    parameters (``temperature``, ``top_p``, ``max_tokens``, ``seed``) are
    accepted but unused. Complete the inference step before relying on it.

    Args:
        eval_data_path: Path to the evaluation data
        server_url: URL of the vLLM server
        is_local: Whether the data is stored locally
        temperature: Temperature for sampling
        top_p: Top-p for sampling
        max_tokens: Maximum number of tokens to generate
        seed: Random seed for reproducibility
        dataset_kwargs: Additional arguments to pass to the load_dataset function
        column_mapping: Mapping of column names

    Returns:
        List of responses from the vLLM server
    """
    # Initialize the vLLM client pointed at the target server
    client = VLLMClient(server_url)

    # Load the evaluation data; avoid a mutable default by creating the
    # kwargs dict only when the caller passed nothing.
    if dataset_kwargs is None:
        dataset_kwargs = {}

    eval_data = load_data(eval_data_path, is_local, **dataset_kwargs)

    # Convert the data to conversations (column_mapping renames source
    # columns — presumably to the prompt/response fields the formatter
    # expects; verify against convert_to_conversations).
    conversations = convert_to_conversations(eval_data, column_mapping)

    # Convert the conversations to the vLLM request format
    vllm_formatter = get_formatter("vllm")
    formatted_data = vllm_formatter.format_data(conversations)

    # Run inference on the formatted data
    # TODO(review): build a VLLMInferenceRequest from `formatted_data` and the
    # sampling parameters, send it via `client`, and return the responses.