start_vllm_server.py

#!/usr/bin/env python
"""
Script to start a vLLM server for inference.

This script provides a convenient way to start a vLLM server with various
configuration options. It supports loading models from local paths or
Hugging Face model IDs.

Example usage:
    python start_vllm_server.py --model-path meta-llama/Llama-2-7b-chat-hf
    python start_vllm_server.py --model-path /path/to/local/model --port 8080
    python start_vllm_server.py --config /path/to/config.yaml
    python start_vllm_server.py  # Uses the default config.yaml in the parent directory
"""

import argparse
import json
import logging
import subprocess
import sys
from pathlib import Path
from typing import Dict, Optional

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Try to import yaml for config file support
try:
    import yaml

    HAS_YAML = True
except ImportError:
    HAS_YAML = False
    logger.warning("PyYAML not installed. Config file support limited to JSON format.")


def read_config(config_path: str) -> Dict:
    """
    Read the configuration file (supports both JSON and YAML formats).

    Args:
        config_path: Path to the configuration file

    Returns:
        dict: Configuration parameters

    Raises:
        ValueError: If the file format is not supported
        ImportError: If the required package for the file format is not installed
    """
    file_extension = Path(config_path).suffix.lower()
    with open(config_path, "r") as f:
        if file_extension == ".json":
            config = json.load(f)
        elif file_extension in [".yaml", ".yml"]:
            if not HAS_YAML:
                raise ImportError(
                    "The 'pyyaml' package is required to load YAML files. "
                    "Please install it with 'pip install pyyaml'."
                )
            config = yaml.safe_load(f)
        else:
            raise ValueError(
                f"Unsupported config file format: {file_extension}. "
                f"Supported formats are: .json, .yaml, .yml"
            )
    return config


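# Illustrative sketch only: read_config() above returns the parsed file as-is, and
# main() below reads an "inference" section from it. A compatible config.yaml might
# look like the following; the key names match what main() reads, but the values are
# placeholder assumptions, not requirements:
#
#   inference:
#     model_path: meta-llama/Llama-2-7b-chat-hf
#     port: 8000
#     host: "0.0.0.0"
#     tensor_parallel_size: 1
#     gpu_memory_utilization: 0.9
#     quantization: awq

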
def check_vllm_installed() -> bool:
    """
    Check if vLLM is installed.

    Returns:
        bool: True if vLLM is installed, False otherwise
    """
    try:
        subprocess.run(
            ["vllm", "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
        return True
    except FileNotFoundError:
        return False


def start_vllm_server(
    model_path: str,
    port: int = 8000,
    host: str = "0.0.0.0",
    tensor_parallel_size: int = 1,
    max_model_len: int = 4096,
    max_num_seqs: int = 256,
    quantization: Optional[str] = None,
    gpu_memory_utilization: float = 0.9,
    enforce_eager: bool = False,
    additional_args: Optional[Dict] = None,
) -> None:
    """
    Start a vLLM server with the specified parameters.

    Args:
        model_path: Path to the model or Hugging Face model ID
        port: Port to run the server on
        host: Host to run the server on
        tensor_parallel_size: Number of GPUs to use for tensor parallelism
        max_model_len: Maximum sequence length
        max_num_seqs: Maximum number of sequences
        quantization: Quantization method (e.g., "awq", "gptq", "squeezellm")
        gpu_memory_utilization: Fraction of GPU memory to use
        enforce_eager: Whether to enforce eager execution
        additional_args: Additional options to append to the vLLM command line

    Raises:
        SystemExit: If vLLM is not installed or the server fails to start
    """
    # Check if vLLM is installed
    if not check_vllm_installed():
        logger.error(
            "vLLM is not installed. Please install it with 'pip install vllm'."
        )
        sys.exit(1)

    # Build the command
    cmd = ["vllm", "serve", model_path]

    # Add basic parameters
    cmd.extend(["--port", str(port)])
    cmd.extend(["--host", host])
    cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
    cmd.extend(["--max-model-len", str(max_model_len)])
    cmd.extend(["--max-num-seqs", str(max_num_seqs)])
    cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])

    # Add optional parameters
    if quantization:
        cmd.extend(["--quantization", quantization])
    if enforce_eager:
        cmd.append("--enforce-eager")

    # Pass through any extra options from the config. This assumes each key maps to
    # a vLLM CLI flag with underscores replaced by dashes (e.g. "dtype" -> "--dtype");
    # boolean True values become bare flags, and False values are skipped.
    if additional_args:
        for key, value in additional_args.items():
            flag = f"--{key.replace('_', '-')}"
            if isinstance(value, bool):
                if value:
                    cmd.append(flag)
            else:
                cmd.extend([flag, str(value)])

    # Log the command
    logger.info(f"Starting vLLM server with command: {' '.join(cmd)}")

    # Run the command
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to start vLLM server: {e}")
        sys.exit(1)
    except KeyboardInterrupt:
        logger.info("vLLM server stopped by user.")
        sys.exit(0)


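# Usage sketch for the function above (assumes the "vllm serve" CLI is on PATH):
# a call such as
#   start_vllm_server("meta-llama/Llama-2-7b-chat-hf", port=8080, tensor_parallel_size=2)
# assembles and runs roughly:
#   vllm serve meta-llama/Llama-2-7b-chat-hf --port 8080 --host 0.0.0.0 \
#     --tensor-parallel-size 2 --max-model-len 4096 --max-num-seqs 256 \
#     --gpu-memory-utilization 0.9

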
def find_config_file():
    """
    Find the config.yaml file in the parent directory.

    Returns:
        str: Path to the config file, or None if no config.yaml exists there
    """
    # Try to find the config file in the parent directory
    script_dir = Path(__file__).resolve().parent
    parent_dir = script_dir.parent
    config_path = parent_dir / "config.yaml"
    if config_path.exists():
        return str(config_path)
    else:
        return None


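# Assumed layout for the default lookup above (illustrative only; the directory
# names are placeholders, the only requirement is config.yaml one level up):
#
#   project/
#   |-- config.yaml               <- found by find_config_file()
#   `-- scripts/
#       `-- start_vllm_server.py  <- this script

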
def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="Start a vLLM server for inference")

    # Configuration options
    config_group = parser.add_argument_group("Configuration")
    config_group.add_argument(
        "--config",
        type=str,
        help="Path to a configuration file (JSON or YAML)",
    )

    # Model options
    model_group = parser.add_argument_group("Model")
    model_group.add_argument(
        "--model-path",
        type=str,
        help="Path to the model or Hugging Face model ID",
    )
    model_group.add_argument(
        "--quantization",
        type=str,
        choices=["awq", "gptq", "squeezellm"],
        help="Quantization method to use",
    )

    # Server options
    server_group = parser.add_argument_group("Server")
    server_group.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to run the server on",
    )
    server_group.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to run the server on",
    )

    # Performance options
    perf_group = parser.add_argument_group("Performance")
    perf_group.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs to use for tensor parallelism",
    )
    perf_group.add_argument(
        "--max-model-len",
        type=int,
        default=4096,
        help="Maximum sequence length",
    )
    perf_group.add_argument(
        "--max-num-seqs",
        type=int,
        default=256,
        help="Maximum number of sequences",
    )
    perf_group.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="Fraction of GPU memory to use",
    )
    perf_group.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Enforce eager execution",
    )

    args = parser.parse_args()

    # Load config file
    config = {}
    config_path = args.config

    # If no config file is provided, try to find the default one
    if not config_path:
        config_path = find_config_file()
        if config_path:
            logger.info(f"Using default config file: {config_path}")

    if config_path:
        try:
            config = read_config(config_path)
            logger.info(f"Loaded configuration from {config_path}")
        except Exception as e:
            logger.error(f"Failed to load configuration from {config_path}: {e}")
            sys.exit(1)

    # Extract inference section from config if it exists
    inference_config = config.get("inference", {})

    # Merge command-line arguments with config file
    # Command-line arguments take precedence
    model_path = args.model_path or inference_config.get("model_path")
    if not model_path:
        logger.error(
            "Model path must be provided either via --model-path or in the config "
            "file under inference.model_path"
        )
        sys.exit(1)

    # Extract parameters
    params = {
        "model_path": model_path,
        "port": (
            args.port
            if args.port != parser.get_default("port")
            else inference_config.get("port", args.port)
        ),
        "host": (
            args.host
            if args.host != parser.get_default("host")
            else inference_config.get("host", args.host)
        ),
        "tensor_parallel_size": (
            args.tensor_parallel_size
            if args.tensor_parallel_size != parser.get_default("tensor_parallel_size")
            else inference_config.get("tensor_parallel_size", args.tensor_parallel_size)
        ),
        "max_model_len": (
            args.max_model_len
            if args.max_model_len != parser.get_default("max_model_len")
            else inference_config.get("max_model_len", args.max_model_len)
        ),
        "max_num_seqs": (
            args.max_num_seqs
            if args.max_num_seqs != parser.get_default("max_num_seqs")
            else inference_config.get("max_num_seqs", args.max_num_seqs)
        ),
        "quantization": args.quantization or inference_config.get("quantization"),
        "gpu_memory_utilization": (
            args.gpu_memory_utilization
            if args.gpu_memory_utilization != parser.get_default("gpu_memory_utilization")
            else inference_config.get("gpu_memory_utilization", args.gpu_memory_utilization)
        ),
        "enforce_eager": args.enforce_eager
        or inference_config.get("enforce_eager", False),
    }

    # Get additional arguments from inference config
    additional_args = {k: v for k, v in inference_config.items() if k not in params}
    if additional_args:
        params["additional_args"] = additional_args

    # Start the vLLM server
    start_vllm_server(**params)


if __name__ == "__main__":
    main()