# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

"""Interactive command-line inference with vLLM's async engine, optionally with a LoRA adapter."""

import asyncio
import uuid

import fire
import torch
from accelerate.utils import is_xpu_available
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

# Seed the accelerator RNG (XPU when available, otherwise CUDA) and the
# default torch RNG with the same fixed value.
if is_xpu_available():
    torch.xpu.manual_seed(42)
else:
    torch.cuda.manual_seed(42)
torch.manual_seed(42)

# Integer id attached to the single LoRA adapter this script serves.
# vLLM expects adapter ids to be positive (>= 1); an id of 0 is rejected
# when the LoRARequest is built.
_LORA_ADAPTER_ID = 1


def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
    """Build an ``AsyncLLMEngine`` for ``model_name``.

    Args:
        model_name: Model identifier passed to ``AsyncEngineArgs(model=...)``.
        peft_model: When truthy, the engine is created with
            ``enable_lora=True``. The value itself is not loaded here; it is
            only used as a flag.
        pp_size: Value for ``pipeline_parallel_size``.
        tp_size: Value for ``tensor_parallel_size``.

    Returns:
        The engine produced by ``AsyncLLMEngine.from_engine_args``.
    """
    additional_configs = {}
    if peft_model:
        additional_configs["enable_lora"] = True

    engine_config = AsyncEngineArgs(
        model=model_name,
        pipeline_parallel_size=pp_size,
        tensor_parallel_size=tp_size,
        max_loras=1,
        **additional_configs,
    )

    return AsyncLLMEngine.from_engine_args(engine_config)


async def main(
    model,
    peft_model_name=None,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8,
):
    """Run a prompt/response loop against ``model`` until an empty prompt is entered.

    Args:
        model: Engine object exposing an async-iterable ``generate(...)``
            (as returned by ``load_model``).
        peft_model_name: Optional LoRA adapter location; when truthy, every
            request is sent with a ``LoRARequest`` pointing at it.
        max_new_tokens: Passed to ``SamplingParams(max_tokens=...)``.
        user_prompt: First prompt. When ``None`` the user is asked for one on
            stdin.
        top_p: Passed to ``SamplingParams(top_p=...)``.
        temperature: Passed to ``SamplingParams(temperature=...)``.
    """
    # The adapter request does not change between prompts, so build it once.
    lora_request = None
    if peft_model_name:
        lora_request = LoRARequest("lora", _LORA_ADAPTER_ID, peft_model_name)

    while True:
        if user_prompt is None:
            user_prompt = input("Enter your prompt: ")

        print(f"User prompt:\n{user_prompt}")
        print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")

        sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)

        # Each request gets a fresh unique id.
        req_id = str(uuid.uuid4())
        generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)

        # Drain the stream and keep only the last item it yields
        # (presumably the final, complete output for the request).
        output = None
        async for request_output in generator:
            output = request_output

        # Guard against a stream that yielded nothing (or an item without
        # completions) instead of failing on ``output.outputs[0]``.
        if output is None or not output.outputs:
            print("model output:\n <no output was returned for this request>")
        else:
            print(f"model output:\n {user_prompt} {output.outputs[0].text}")

        user_prompt = input("Enter next prompt (press Enter to exit): ")
        if not user_prompt:
            break


def run_script(
    model_name: str,
    peft_model_name=None,
    pp_size: int = 1,
    tp_size: int = 1,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8,
):
    """Load the model, then drive the interactive ``main`` loop to completion.

    Args:
        model_name: Model identifier forwarded to ``load_model``.
        peft_model_name: Optional LoRA adapter location; enables LoRA in the
            engine and is attached to every request.
        pp_size: Pipeline-parallel size forwarded to ``load_model``.
        tp_size: Tensor-parallel size forwarded to ``load_model``.
        max_new_tokens: Forwarded to ``main``.
        user_prompt: Optional first prompt forwarded to ``main``.
        top_p: Forwarded to ``main``.
        temperature: Forwarded to ``main``.
    """
    model = load_model(model_name, peft_model_name, pp_size, tp_size)

    # The engine is created before the loop runs, so the coroutine is driven
    # on the current event loop rather than a freshly created one.
    asyncio.get_event_loop().run_until_complete(
        main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature)
    )


if __name__ == "__main__":
    fire.Fire(run_script)