import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

 
@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    best_of: int = 1
    use_beam_search: bool = False

 
@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""

 
async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark a single streaming request against a TGI ``generate_stream``
    endpoint, recording TTFT, inter-token latencies, and total latency."""
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output

 
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark a single streaming request against a TensorRT-LLM
    ``generate_stream`` endpoint."""
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output

 
async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark a single non-streaming request against a DeepSpeed-MII
    server; TTFT is reported as 0 because streaming is not supported."""
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output

 
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark a single streaming request against an OpenAI-compatible
    Completions endpoint (used for vLLM, LMDeploy, ScaleLLM, etc.)."""
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output

 
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Benchmark a single streaming request against an OpenAI-compatible
    Chat Completions endpoint."""
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta["content"]

                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output

 
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
    if text.startswith(prefix):
        return text[len(prefix):]
    return text

 
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve the model path, downloading from ModelScope when the
    VLLM_USE_MODELSCOPE environment variable is set to 'true'."""
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download
        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
        return model_path
    return pretrained_model_name_or_path

 
def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
                                         trust_remote_code=trust_remote_code)

 
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
}
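

# A minimal usage sketch (not part of the benchmark harness itself): the
# endpoint URL, model name, and token counts below are hypothetical
# placeholders that only illustrate how a caller would construct a
# RequestFuncInput and drive one of the request functions above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        request_input = RequestFuncInput(
            prompt="Hello, world!",
            api_url="http://localhost:8000/v1/completions",  # placeholder URL
            prompt_len=4,
            output_len=16,
            model="example-model",  # placeholder model name
        )
        # Pick the backend-specific request function from the registry.
        request_func = ASYNC_REQUEST_FUNCS["openai"]
        result = await request_func(request_input)
        print(result.success, result.ttft, result.latency, result.error)

    asyncio.run(_demo())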
 
 