  1. """Benchmark online serving throughput.
  2. On the server side, run one of the following commands:
  3. vLLM OpenAI API server
  4. vllm serve <your_model> \
  5. --swap-space 16 \
  6. --disable-log-requests
  7. (TGI backend)
  8. ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
  9. On the client side, run:
  10. python benchmarks/benchmark_serving.py \
  11. --backend <backend> \
  12. --model <your_model> \
  13. --dataset-name sharegpt \
  14. --dataset-path <path to dataset> \
  15. --request-rate <request_rate> \ # By default <request_rate> is inf
  16. --num-prompts <num_prompts> # By default <num_prompts> is 1000
  17. when using tgi backend, add
  18. --endpoint /generate_stream
  19. to the end of the command above.
  20. """
import argparse
import asyncio
import json
import os
import random
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
                                  RequestFuncOutput)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
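    """Aggregated serving metrics for a single benchmark run.

    TTFT is time to first token, TPOT is time per output token
    (excluding the first token), and ITL is inter-token latency.
    All latency fields are in milliseconds.
    """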
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
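    """Sample up to `num_requests` (prompt, prompt_len, output_len) tuples
    from a ShareGPT-format dataset, keeping only the first two turns of each
    conversation and pruning sequences that are too short (< 4 tokens) or too
    long (> 1024 prompt tokens or > 2048 prompt + output tokens).
    """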
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def sample_sonnet_requests(
    dataset_path: str,
    num_requests: int,
    input_len: int,
    output_len: int,
    prefix_len: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
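    """Build `num_requests` prompts from the sonnet dataset.

    Each prompt shares a fixed prefix of roughly `prefix_len` tokens and is
    padded with randomly sampled poem lines up to roughly `input_len` tokens.
    Returns (prompt, prompt_formatted, prompt_len, output_len) tuples, where
    `prompt_formatted` has the tokenizer's chat template applied.
    """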
    assert (
        input_len > prefix_len
    ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."

    # Load the dataset.
    with open(dataset_path) as f:
        poem_lines = f.readlines()

    # Tokenize the poem lines.
    poem_token_ids = tokenizer(poem_lines).input_ids
    average_poem_len = sum(
        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)

    # Base prefix for all requests.
    base_prompt = "Pick as many lines as you can from these poem lines:\n"
    base_message = [{
        "role": "user",
        "content": base_prompt,
    }]
    base_prompt_formatted = tokenizer.apply_chat_template(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (
        input_len > base_prompt_offset
    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

    # First approximately `prefix_len` number of tokens in the
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
    prefix_lines = poem_lines[:num_prefix_lines]

    # Sample the rest of lines per request.
    sampled_requests: List[Tuple[str, str, int, int]] = []
    for _ in range(num_requests):
        sampled_lines = "".join(
            prefix_lines +
            random.sample(poem_lines, num_input_lines - num_prefix_lines))

        prompt = f"{base_prompt}{sampled_lines}"
        message = [
            {
                "role": "user",
                "content": prompt,
            },
        ]
        prompt_formatted = tokenizer.apply_chat_template(
            message, add_generation_prompt=True, tokenize=False)
        prompt_len = len(tokenizer(prompt_formatted).input_ids)
        sampled_requests.append(
            (prompt, prompt_formatted, prompt_len, output_len))

    return sampled_requests


def sample_random_requests(
        input_len: int, output_len: int, num_prompts: int, range_ratio: float,
        tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
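    """Generate `num_prompts` synthetic (prompt, prompt_len, output_len)
    tuples. Prompts are decoded from sequential token IDs starting at a
    random offset; input and output lengths are sampled uniformly from
    [int(len * range_ratio), len].
    """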
    input_lens = np.random.randint(
        int(input_len * range_ratio),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )
    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
    input_requests = []
    for i in range(num_prompts):
        prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
                                   for j in range(input_lens[i])])
        input_requests.append(
            (prompt, int(input_lens[i]), int(output_lens[i])))

    return input_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
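    """Yield requests one by one.

    If `request_rate` is inf, all requests are yielded immediately; otherwise
    inter-arrival times are sampled from an exponential distribution so that
    arrivals follow a Poisson process with the given rate.
    """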
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
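    """Compute aggregate BenchmarkMetrics from per-request outputs.

    Output token counts are re-tokenized from the generated text, and TPOT is
    computed as (latency - TTFT) / (output_len - 1), i.e. excluding the first
    token. Returns the metrics and the per-request output lengths.
    """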
    actual_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            # We use the tokenizer to count the number of output tokens for
            # all serving backends instead of looking at len(outputs[i].itl)
            # since multiple output tokens may be bundled together.
            # Note: this may inflate the output token count slightly.
            output_len = len(
                tokenizer(outputs[i].generated_text,
                          add_special_tokens=False).input_ids)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append(
                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens


async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
    disable_tqdm: bool,
    profile: bool,
):
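    """Run the benchmark against `api_url`.

    A single warm-up request is sent first; the remaining requests are
    dispatched according to `request_rate`, awaited concurrently, and
    summarized via `calculate_metrics`. Returns a dict of aggregate and
    per-request results.
    """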
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        best_of=best_of,
        use_beam_search=use_beam_search,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")

    if profile:
        print("Starting profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
            api_url=base_url + "/start_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler started")

    print(f"Traffic request rate: {request_rate}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input,
                             pbar=pbar)))
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if profile:
        print("Stopping profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
            api_url=base_url + "/stop_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                    benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:",
                                 metrics.total_output))
    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
                                    metrics.request_throughput))
    print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
                                    metrics.input_throughput))
    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
                                    metrics.output_throughput))
    print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
                                    metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
                               n=50,
                               c='-'))
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
                                    metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }
    return result


def main(args: argparse.Namespace):
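    """Prepare the tokenizer and dataset, run the benchmark, and optionally
    save the results to a JSON file."""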
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
        base_url = f"{args.base_url}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"
        base_url = f"http://{args.host}:{args.port}"

    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)

    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next "
            "release. Please use '--dataset-name' and "
            "'--dataset-path' in the future runs.",
            stacklevel=2)
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sonnet":
        # Do not format the prompt, pass to message directly
        if args.backend == "openai-chat":
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]
        else:
            assert (
                tokenizer.chat_template or tokenizer.default_chat_template
            ), "Tokenizer/model must have chat template for sonnet dataset."
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt_formatted, prompt_len, output_len)
                              for prompt, prompt_formatted, prompt_len,
                              output_len in input_requests]

    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
        )

    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            best_of=args.best_of,
            use_beam_search=args.use_beam_search,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
            profile=args.profile,
        ))

    # Save config and results to json
    if args.save_result:
        result_json: Dict[str, Any] = {}

        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
        result_json["backend"] = backend
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
        result_json["best_of"] = args.best_of
        result_json["use_beam_search"] = args.use_beam_search
        result_json["num_prompts"] = args.num_prompts

        # Metadata
        if args.metadata:
            for item in args.metadata:
                if "=" in item:
                    kvstring = item.split("=")
                    result_json[kvstring[0].strip()] = kvstring[1].strip()
                else:
                    raise ValueError(
                        "Invalid metadata format. Please use KEY=VALUE format."
                    )

        # Traffic
        result_json["request_rate"] = (
            args.request_rate if args.request_rate < float("inf") else "inf")

        # Merge with benchmark result
        result_json = {**result_json, **benchmark_result}

        # Save to file
        base_model_id = model_id.split("/")[-1]
        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  # noqa
        if args.result_filename:
            file_name = args.result_filename
        if args.result_dir:
            file_name = os.path.join(args.result_dir, file_name)
        with open(file_name, "w") as outfile:
            json.dump(result_json, outfile)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in the "
        "next release.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "sonnet", "random"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument(
        "--best-of",
        type=int,
        default=1,
        help="Generates `best_of` sequences per prompt and "
        "returns the best one.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.")
    parser.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help=
        "Number of input tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help=
        "Number of output tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help=
        "Number of prefix tokens per request, used only for sonnet dataset.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help=
        "Number of input tokens per request, used only for random sampling.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help=
        "Number of output tokens per request, used only for random sampling.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=1.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random sampling.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize "
        "the request arrival times.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
        " format.",
    )

    args = parser.parse_args()
    main(args)