|
@@ -1,11 +1,13 @@
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
+import uuid
|
|
|
+import asyncio
|
|
|
import fire
|
|
|
|
|
|
import torch
|
|
|
-from vllm import LLM
|
|
|
-from vllm import LLM, SamplingParams
|
|
|
+from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
|
|
|
+from vllm.lora.request import LoRARequest
|
|
|
from accelerate.utils import is_xpu_available
|
|
|
|
|
|
if is_xpu_available():
|
|
@@ -15,13 +17,24 @@ else:
|
|
|
|
|
|
torch.manual_seed(42)
|
|
|
|
|
|
-def load_model(model_name, tp_size=1):
|
|
|
+def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
|
|
|
+ additional_configs = {}
|
|
|
+ if peft_model:
|
|
|
+ additional_configs["enable_lora"] = True
|
|
|
+
|
|
|
+ engine_config = AsyncEngineArgs(
|
|
|
+ model=model_name,
|
|
|
+ pipeline_parallel_size=pp_size,
|
|
|
+ tensor_parallel_size=tp_size,
|
|
|
+ max_loras=1,
|
|
|
+ **additional_configs)
|
|
|
|
|
|
- llm = LLM(model_name, tensor_parallel_size=tp_size)
|
|
|
+ llm = AsyncLLMEngine.from_engine_args(engine_config)
|
|
|
return llm
|
|
|
|
|
|
-def main(
|
|
|
+async def main(
|
|
|
model,
|
|
|
+ peft_model_name=None,
|
|
|
max_new_tokens=100,
|
|
|
user_prompt=None,
|
|
|
top_p=0.9,
|
|
@@ -35,26 +48,36 @@ def main(
|
|
|
|
|
|
print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
|
|
|
sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)
|
|
|
-
|
|
|
|
|
|
- outputs = model.generate(user_prompt, sampling_params=sampling_param)
|
|
|
+ lora_request = None
|
|
|
+ if peft_model_name:
|
|
|
+ lora_request = LoRARequest("lora",0,peft_model_name)
|
|
|
+
|
|
|
+ req_id = str(uuid.uuid4())
|
|
|
+
|
|
|
+ generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)
|
|
|
+ output = None
|
|
|
+ async for request_output in generator:
|
|
|
+ output = request_output
|
|
|
|
|
|
- print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
|
|
|
+ print(f"model output:\n {user_prompt} {output.outputs[0].text}")
|
|
|
user_prompt = input("Enter next prompt (press Enter to exit): ")
|
|
|
if not user_prompt:
|
|
|
break
|
|
|
|
|
|
def run_script(
|
|
|
model_name: str,
|
|
|
- peft_model=None,
|
|
|
- tp_size=1,
|
|
|
+ peft_model_name=None,
|
|
|
+ pp_size : int = 1,
|
|
|
+ tp_size : int = 1,
|
|
|
max_new_tokens=100,
|
|
|
user_prompt=None,
|
|
|
top_p=0.9,
|
|
|
temperature=0.8
|
|
|
):
|
|
|
- model = load_model(model_name, tp_size)
|
|
|
- main(model, max_new_tokens, user_prompt, top_p, temperature)
|
|
|
+ model = load_model(model_name, peft_model_name, pp_size, tp_size)
|
|
|
+
|
|
|
+ asyncio.get_event_loop().run_until_complete(main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
fire.Fire(run_script)
|