# ## Serving Llama 3.2 3B Instruct Model With vLLM
# This app runs a vLLM server on an A100 GPU.
#
# Run it with:
#     modal deploy inference
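#
# Once deployed, the app exposes an OpenAI-compatible HTTP API. A minimal sketch of a
# client call (the URL below is a placeholder -- use the one Modal prints on deploy,
# and the bearer token defined as TOKEN further down in this file):
#
#     curl https://your-workspace--many-llamas-human-eval-serve.modal.run/v1/completions \
#         -H "Authorization: Bearer super-secret-token" \
#         -H "Content-Type: application/json" \
#         -d '{"model": "meta-llama/Llama-3.2-3B-Instruct", "prompt": "def fib(n):", "max_tokens": 64}'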

import modal

# This defines the image to use for the vLLM server container
vllm_image = modal.Image.debian_slim(python_version="3.10").pip_install(
    "vllm==0.5.3post1"
)

MODELS_DIR = "/llamas"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Ensure the model is downloaded and the volume exists
try:
    volume = modal.Volume.lookup("llamas", create_if_missing=False)
except modal.exception.NotFoundError:
    raise Exception("Download models first with modal run download")
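
# The download step lives in a separate script. A hypothetical sketch of what that
# `download.py` could look like (the app name, secret name, and use of huggingface_hub
# are assumptions; gated Llama weights also need an accepted license and an HF token):
#
#     import modal
#
#     volume = modal.Volume.from_name("llamas", create_if_missing=True)
#     image = modal.Image.debian_slim(python_version="3.10").pip_install("huggingface_hub")
#     app = modal.App("download-llamas", image=image)
#
#     @app.function(
#         volumes={"/llamas": volume},
#         timeout=60 * 60,
#         secrets=[modal.Secret.from_name("huggingface-secret")],  # assumed HF token secret
#     )
#     def download(model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
#         from huggingface_hub import snapshot_download
#
#         snapshot_download(model_name, local_dir=f"/llamas/{model_name}")
#         volume.commit()
#
#     @app.local_entrypoint()
#     def main():
#         download.remote()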

app = modal.App("many-llamas-human-eval")

N_GPU = 1  # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count

TOKEN = (
    "super-secret-token"  # auth token. for production use, replace with a modal.Secret
)
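
# A sketch of the modal.Secret approach mentioned above (the secret name and the
# AUTH_TOKEN variable are assumptions -- create the secret first, e.g. with
# `modal secret create inference-auth-token AUTH_TOKEN=...`):
#
#     @app.function(..., secrets=[modal.Secret.from_name("inference-auth-token")])
#     def serve():
#         import os
#         token = os.environ["AUTH_TOKEN"]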

MINUTES = 60  # seconds
HOURS = 60 * MINUTES

@app.function(
    image=vllm_image,
    gpu=modal.gpu.A100(count=N_GPU, size="40GB"),
    container_idle_timeout=5 * MINUTES,
    timeout=24 * HOURS,
    allow_concurrent_inputs=20,  # vLLM will batch requests, so many can be received at once
    volumes={MODELS_DIR: volume},
    concurrency_limit=10,  # max 10 GPUs
)
@modal.asgi_app()
def serve():
    import fastapi
    import vllm.entrypoints.openai.api_server as api_server
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.entrypoints.logger import RequestLogger
    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
    from vllm.entrypoints.openai.serving_completion import (
        OpenAIServingCompletion,
    )
    from vllm.usage.usage_lib import UsageContext

    volume.reload()  # ensure we have the latest version of the weights

    # create a FastAPI app that uses vLLM's OpenAI-compatible router
    web_app = fastapi.FastAPI(
        title=f"OpenAI-compatible {MODEL_NAME} server",
        description="Run an OpenAI-compatible LLM server with vLLM on modal.com",
        version="0.0.1",
        docs_url="/docs",
    )

    # security: an HTTP Bearer scheme for authenticating requests
    http_bearer = fastapi.security.HTTPBearer(
        scheme_name="Bearer Token",
        description="See code for authentication details.",
    )

    # permissive CORS middleware so external clients can reach the server
    web_app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # security: inject dependency on authed routes
    async def is_authenticated(api_key: str = fastapi.Security(http_bearer)):
        if api_key.credentials != TOKEN:
            raise fastapi.HTTPException(
                status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authentication credentials",
            )
        return {"username": "authenticated_user"}

    router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)])

    # wrap vLLM's router in the auth router
    router.include_router(api_server.router)
    # add the authed vLLM routes to our FastAPI app
    web_app.include_router(router)

    engine_args = AsyncEngineArgs(
        model=MODELS_DIR + "/" + MODEL_NAME,
        tensor_parallel_size=N_GPU,
        gpu_memory_utilization=0.90,
        max_model_len=2048,
        enforce_eager=False,  # capture the graph for faster inference, but slower cold starts (30s > 20s)
    )

    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER
    )

    model_config = get_model_config(engine)

    request_logger = RequestLogger(max_log_len=2048)

    api_server.openai_serving_chat = OpenAIServingChat(
        engine,
        model_config=model_config,
        served_model_names=[MODEL_NAME],
        chat_template=None,
        response_role="assistant",
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
    )
    api_server.openai_serving_completion = OpenAIServingCompletion(
        engine,
        model_config=model_config,
        served_model_names=[MODEL_NAME],
        lora_modules=[],
        prompt_adapters=[],
        request_logger=request_logger,
    )

    return web_app

def get_model_config(engine):
    import asyncio

    try:  # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1
        event_loop = asyncio.get_running_loop()
    except RuntimeError:
        event_loop = None

    if event_loop is not None and event_loop.is_running():
        # if the current process was started by Ray Serve,
        # there is already a running event loop
        model_config = event_loop.run_until_complete(engine.get_model_config())
    else:
        # when using a single vLLM instance without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

    return model_config
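
# For interactive testing, the standard OpenAI Python client also works against this
# server, since it speaks the same /v1 API. A minimal sketch (the base_url is a
# placeholder -- substitute the URL Modal prints on deploy):
#
#     from openai import OpenAI
#
#     client = OpenAI(
#         base_url="https://your-workspace--many-llamas-human-eval-serve.modal.run/v1",
#         api_key="super-secret-token",
#     )
#     response = client.chat.completions.create(
#         model="meta-llama/Llama-3.2-3B-Instruct",
#         messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
#     )
#     print(response.choices[0].message.content)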