import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Union

import aiohttp
import huggingface_hub.constants
from tqdm.asyncio import tqdm
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    best_of: int = 1
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta["content"]

                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download
        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
                                         trust_remote_code=trust_remote_code)


ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
}
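

# Example usage (illustrative sketch only, not part of the benchmark harness).
# It assumes an OpenAI-compatible server (e.g. vLLM) is already running at
# http://localhost:8000 and serving a model named "facebook/opt-125m"; both
# the URL and the model name below are placeholders to be adjusted.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # "vllm" maps to async_request_openai_completions in
        # ASYNC_REQUEST_FUNCS above.
        request_func = ASYNC_REQUEST_FUNCS["vllm"]
        demo_input = RequestFuncInput(
            prompt="Hello, world!",
            api_url="http://localhost:8000/v1/completions",  # assumed endpoint
            prompt_len=4,  # token count of the prompt (metadata only)
            output_len=16,
            model="facebook/opt-125m",  # assumed model name
        )
        result = await request_func(demo_input)
        print(f"success={result.success} ttft={result.ttft:.3f}s "
              f"latency={result.latency:.3f}s")
        print(result.generated_text or result.error)

    asyncio.run(_demo())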