# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
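
# Interactive vLLM inference example: generates text asynchronously with AsyncLLMEngine,
# optionally applying a LoRA (PEFT) adapter on top of the base model.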
 
import uuid
import asyncio

import fire
import torch
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.lora.request import LoRARequest
from accelerate.utils import is_xpu_available

# Seed the device RNG (XPU or CUDA) and the CPU RNG for reproducible sampling.
if is_xpu_available():
    torch.xpu.manual_seed(42)
else:
    torch.cuda.manual_seed(42)
torch.manual_seed(42)


def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
    """Build a vLLM AsyncLLMEngine for the given model, enabling LoRA support when a PEFT adapter is passed."""
    additional_configs = {}
    if peft_model:
        additional_configs["enable_lora"] = True

    engine_config = AsyncEngineArgs(
        model=model_name,
        pipeline_parallel_size=pp_size,
        tensor_parallel_size=tp_size,
        max_loras=1,
        **additional_configs)
    llm = AsyncLLMEngine.from_engine_args(engine_config)
    return llm
 

async def main(
    model,
    peft_model_name=None,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    while True:
        if user_prompt is None:
            user_prompt = input("Enter your prompt: ")

        print(f"User prompt:\n{user_prompt}")
        print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
        sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)

        # Attach the LoRA adapter to this request if a PEFT model was provided.
        lora_request = None
        if peft_model_name:
            lora_request = LoRARequest("lora", 0, peft_model_name)

        # Each request needs a unique id; consume the async generator and keep the final output.
        req_id = str(uuid.uuid4())
        generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)
        output = None
        async for request_output in generator:
            output = request_output

        print(f"model output:\n {user_prompt} {output.outputs[0].text}")
        user_prompt = input("Enter next prompt (press Enter to exit): ")
        if not user_prompt:
            break
 

def run_script(
    model_name: str,
    peft_model_name=None,
    pp_size: int = 1,
    tp_size: int = 1,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    model = load_model(model_name, peft_model_name, pp_size, tp_size)
    asyncio.get_event_loop().run_until_complete(
        main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature)
    )


if __name__ == "__main__":
    fire.Fire(run_script)
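
# Example invocation (illustrative sketch; the script filename and model checkpoint below
# are assumptions, substitute any vLLM-compatible model and optional LoRA adapter path):
#   python inference.py --model_name meta-llama/Llama-2-7b-hf --tp_size 1
#   python inference.py --model_name meta-llama/Llama-2-7b-hf --peft_model_name ./lora_adapter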
 
 