inference.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import uuid
import asyncio

import fire
import torch
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.lora.request import LoRARequest
from accelerate.utils import is_xpu_available

# Seed the device RNG (XPU or CUDA) and the CPU RNG for reproducible sampling.
if is_xpu_available():
    torch.xpu.manual_seed(42)
else:
    torch.cuda.manual_seed(42)
torch.manual_seed(42)
def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
    """Build an async vLLM engine, enabling LoRA support when a PEFT adapter is given."""
    additional_configs = {}
    if peft_model:
        additional_configs["enable_lora"] = True

    engine_config = AsyncEngineArgs(
        model=model_name,
        pipeline_parallel_size=pp_size,
        tensor_parallel_size=tp_size,
        max_loras=1,
        **additional_configs)

    llm = AsyncLLMEngine.from_engine_args(engine_config)
    return llm
async def main(
    model,
    peft_model_name=None,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    # Interactive loop: prompt, generate, print, repeat until an empty prompt.
    while True:
        if user_prompt is None:
            user_prompt = input("Enter your prompt: ")

        print(f"User prompt:\n{user_prompt}")
        print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
        sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)

        lora_request = None
        if peft_model_name:
            # vLLM reserves LoRA ID 0 for "no adapter"; adapter IDs must be >= 1,
            # so pass 1 rather than 0 here.
            lora_request = LoRARequest("lora", 1, peft_model_name)
        # Each engine request needs a unique ID. generate() returns an async
        # stream of partial results; keep only the last, which holds the full text.
        req_id = str(uuid.uuid4())
        generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)
        output = None
        async for request_output in generator:
            output = request_output
        print(f"model output:\n {user_prompt} {output.outputs[0].text}")

        user_prompt = input("Enter next prompt (press Enter to exit): ")
        if not user_prompt:
            break
def run_script(
    model_name: str,
    peft_model_name=None,
    pp_size: int = 1,
    tp_size: int = 1,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    model = load_model(model_name, peft_model_name, pp_size, tp_size)
    asyncio.get_event_loop().run_until_complete(
        main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature))

if __name__ == "__main__":
    fire.Fire(run_script)
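
# Example invocations (a sketch: the model ID and adapter path below are
# illustrative placeholders, not values shipped with this script).
# fire.Fire exposes run_script's parameters as command-line flags:
#
#   python inference.py --model_name meta-llama/Llama-2-7b-chat-hf --tp_size 2
#
# With a fine-tuned LoRA adapter:
#
#   python inference.py --model_name meta-llama/Llama-2-7b-hf \
#       --peft_model_name /path/to/lora_adapter --max_new_tokens 200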