
Enable pipeline parallelism through use of AsyncLLMEngine in vLLM inference + enable use of LoRA adapters

Matthias Reso 9 months ago
parent
commit
c9ae014459
1 changed file with 35 additions and 12 deletions
      recipes/3p_integrations/vllm/inference.py

+ 35 - 12
recipes/3p_integrations/vllm/inference.py

@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import uuid
+import asyncio
 import fire
 
 import torch
-from vllm import LLM
-from vllm import LLM, SamplingParams
+from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
+from vllm.lora.request import LoRARequest
 from accelerate.utils import is_xpu_available
 
 if is_xpu_available():
@@ -15,13 +17,24 @@ else:
 
 torch.manual_seed(42)
 
-def load_model(model_name, tp_size=1):
+def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
+    additional_configs = {}
+    if peft_model:
+        additional_configs["enable_lora"] = True
+
+    engine_config = AsyncEngineArgs(
+        model=model_name,
+        pipeline_parallel_size=pp_size,
+        tensor_parallel_size=tp_size,
+        max_loras=1,
+        **additional_configs)
 
-    llm = LLM(model_name, tensor_parallel_size=tp_size)
+    llm = AsyncLLMEngine.from_engine_args(engine_config)
     return llm
 
-def main(
+async def main(
     model,
+    peft_model_name=None,
     max_new_tokens=100,
     user_prompt=None,
     top_p=0.9,
@@ -35,26 +48,36 @@ def main(
 
         print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
         sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)
-        
 
-        outputs = model.generate(user_prompt, sampling_params=sampling_param)
+        lora_request = None
+        if peft_model_name:
+            # lora_int_id must be a positive integer; 0 is reserved for "no LoRA" in vLLM
+            lora_request = LoRARequest("lora", 1, peft_model_name)
+
+        req_id = str(uuid.uuid4())
+
+        generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)
+        output = None
+        async for request_output in generator:
+            output = request_output
    
-        print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
+        print(f"model output:\n {user_prompt} {output.outputs[0].text}")
         user_prompt = input("Enter next prompt (press Enter to exit): ")
         if not user_prompt:
             break
 
 def run_script(
     model_name: str,
-    peft_model=None,
-    tp_size=1,
+    peft_model_name=None,
+    pp_size: int = 1,
+    tp_size: int = 1,
     max_new_tokens=100,
     user_prompt=None,
     top_p=0.9,
     temperature=0.8
 ):
-    model = load_model(model_name, tp_size)
-    main(model, max_new_tokens, user_prompt, top_p, temperature)
+    model = load_model(model_name, peft_model_name, pp_size, tp_size)
+
+    asyncio.get_event_loop().run_until_complete(main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature))
 
 if __name__ == "__main__":
     fire.Fire(run_script)
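
A minimal usage sketch of the updated script (the model id, adapter path, and GPU counts below are illustrative placeholders, not values taken from this commit):

    # Run via fire's CLI (pipeline parallelism across 2 GPUs, with a LoRA adapter):
    #   python recipes/3p_integrations/vllm/inference.py \
    #       --model_name meta-llama/Meta-Llama-3-8B-Instruct \
    #       --peft_model_name ./my_lora_adapter \
    #       --pp_size 2 --tp_size 1
    # Or call run_script directly from Python:
    from inference import run_script  # assumes inference.py is importable

    run_script(
        model_name="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
        peft_model_name="./my_lora_adapter",               # placeholder LoRA adapter path
        pp_size=2,  # pipeline parallelism: split model layers across 2 GPUs
        tp_size=1,  # tensor parallelism within each pipeline stage
        user_prompt="Hello, how are you?",
    )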