benchmark_prefix_caching.py

  1. """
  2. Benchmark the efficiency of prefix caching.
  3. This script allows you to benchmark the performance of
  4. a model with and without prefix caching using either fixed prompts
  5. or prompts sampled from the ShareGPT dataset.
  6. Fixed example usage:
  7. python benchmark_prefix_caching.py \
  8. --model meta-llama/Llama-2-7b-chat-hf \
  9. --enable-prefix-caching \
  10. --num-prompts 1 \
  11. --repeat-count 100
  12. ShareGPT example usage:
  13. # This command samples 20 prompts with input lengths
  14. # between 128 and 256 tokens from the ShareGPT dataset,
  15. # then replicates each prompt 5 times.
  16. python benchmark_prefix_caching.py \
  17. --model meta-llama/Llama-2-7b-chat-hf \
  18. --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
  19. --enable-prefix-caching \
  20. --num-prompts 20 \
  21. --repeat-count 5 \
  22. --input-length-range 128:256
  23. """

import json
import random
import time
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizerBase

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat's the content in the (1,1) cell?\n"  # noqa: E501
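# The fixed PROMPT above is long (several hundred tokens for typical
# tokenizers), so when it is repeated --repeat-count times, an engine with
# prefix caching enabled can reuse the KV cache of the shared prefix instead
# of recomputing it for every request, which is the effect this benchmark is
# meant to expose.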


def test_prefix(llm=None, sampling_params=None, prompts=None):
    """Generate completions for all prompts and report the wall-clock time."""
    start_time = time.time()
    llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    print(f"cost time {end_time - start_time:.2f}s")


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    input_length_range: Tuple[int, int],
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len must be at least 4")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out conversations with fewer than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    min_len, max_len = input_length_range

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (len(completion_token_ids)
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if min_len <= prompt_len <= max_len:
            filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
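# sample_requests returns (prompt, prompt_len, output_len) tuples; for
# example (hypothetical values), with input_length_range=(128, 256) an entry
# might look like ("Can you summarize this article? ...", 142, 87).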


def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
                             repeat_count: int,
                             sort: bool = False) -> List[str]:
    repeated_requests = requests * repeat_count
    if sort:
        repeated_requests.sort(key=lambda x: x[1])
    else:
        random.shuffle(repeated_requests)
    return [req[0] for req in repeated_requests]
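# Note: the sort above keys on prompt length and list.sort is stable, so with
# --sort the repeated copies of equal-length (in particular, identical)
# prompts end up adjacent, which tends to keep their cached prefix blocks
# hot; without --sort the copies are shuffled across the whole run.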


def main(args):
    tokenizer = get_tokenizer(args.model, trust_remote_code=True)
    input_length_range = tuple(map(int, args.input_length_range.split(':')))

    if args.dataset_path is not None:
        print(f"Sampling {args.num_prompts} prompts "
              f"from {args.dataset_path}")
        filtered_datasets = sample_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            input_length_range=input_length_range,
            fixed_output_len=args.output_len,
        )
    else:
        prompt_len = len(tokenizer(PROMPT).input_ids)
        filtered_datasets = [(PROMPT, prompt_len, args.output_len)
                             ] * args.num_prompts

    llm = LLM(model=args.model,
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              use_v2_block_manager=args.use_v2_block_manager,
              tensor_parallel_size=args.tensor_parallel_size,
              enable_prefix_caching=args.enable_prefix_caching)
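
    # temperature=0 selects greedy decoding, so repeated prompts generate
    # identical outputs and both timing passes do the same amount of work.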
    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

    print("Testing filtered datasets")
    prompts = repeat_and_sort_requests(filtered_datasets,
                                       repeat_count=args.repeat_count,
                                       sort=args.sort)

    # The warm-up pass absorbs one-time startup costs (and, when prefix
    # caching is enabled, populates the cache); the second pass is the
    # measurement to compare across runs.
    print("------warm up------")
    test_prefix(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )

    print("------start generating------")
    test_prefix(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description=
        'Benchmark the performance with or without automatic prefix caching.')
    parser.add_argument('--model',
                        type=str,
                        default='baichuan-inc/Baichuan2-13B-Chat')
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='Enable prefix caching.')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2.')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to sample from the dataset.')
    parser.add_argument('--repeat-count',
                        type=int,
                        default=100,
                        help='Number of times to repeat each prompt.')
    parser.add_argument('--sort',
                        action='store_true',
                        help='Sort prompts by input length.')
    parser.add_argument('--input-length-range',
                        type=str,
                        default='128:256',
                        help='Range of input lengths for sampling prompts, '
                        'specified as "min:max" (e.g., "128:256").')
    args = parser.parse_args()
    main(args)