[shortfin llm] Simplify interface between llm specific code and fastapi webapp (#985)

Currently it's a little hard to trace through the server initialization
code because a lot of global variables are involved. This PR eliminates
them and moves init/deinit into a context manager defined in a single location.
renxida authored Feb 19, 2025
1 parent 8b41015 commit aa31e1e
Showing 5 changed files with 126 additions and 93 deletions.
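
At a glance, the new wiring replaces the module-level globals in lifecycle_hooks.py with a single ShortfinLlmLifecycleManager. A minimal sketch of the resulting startup path, using the names introduced in the diffs below (`args` stands for the parsed command-line namespace built in server.py):

```python
from shortfin_apps.llm.application import get_app
from shortfin_apps.llm.components.lifecycle import ShortfinLlmLifecycleManager

# One object owns init/deinit; FastAPI receives it as a lifespan callable
# instead of reading module-level globals at import time.
manager = ShortfinLlmLifecycleManager(args)
app = get_app(manager.fastapi_lifespan)
```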
8 changes: 5 additions & 3 deletions shortfin/python/shortfin_apps/llm/application.py
@@ -4,11 +4,13 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+"""
+Uses shortfin_apps.llm.components.lifecycle to configure a FastAPI application.
+"""
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from .lifecycle_hooks import lifespan
 from .routes import application_router, generation_router
-from fastapi import FastAPI
 
 
 def add_routes(app: FastAPI):
@@ -27,7 +29,7 @@ def add_middleware(app: FastAPI):
     return app
 
 
-def get_app() -> FastAPI:
+def get_app(lifespan) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app = add_routes(app)
     app = add_middleware(app)
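For context, get_app now leans on FastAPI's lifespan contract: an async context manager function that runs startup code before its `yield` and shutdown code after. A self-contained sketch of that contract with toy state (not the shortfin types):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: publish shared state for route handlers to read.
    app.state.services = {"default": object()}
    yield
    # Shutdown: tear the shared state down again.
    app.state.services = {}


app = FastAPI(lifespan=lifespan)
```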
113 changes: 113 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/lifecycle.py
@@ -0,0 +1,113 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Implements a context manager that configures a shortfin llm server from a namespace mirroring server.py's commandline args, and exposes a context manager interface such that we can do:
```python
def lifecycle(app: FastApi):
with lifecycle_manager(args) as man:
yield
```
"""


from .config_struct import ModelParams, ServerParams
from .manager import SystemManager
from .service import GenerateService
from .tokenizer import Tokenizer
from typing import TYPE_CHECKING
from fastapi import FastAPI


from contextlib import asynccontextmanager
import logging


def get_eos_from_tokenizer_config(json_path):
    import json

    with open(json_path, "rt") as f:
        json_text = f.read()
    config = json.loads(json_text)
    return config["eos_token"]


class ShortfinLlmLifecycleManager:
    """
    Manages the lifecycle of a shortfin llm server, including config loading and parameter setup.
    There are generally two ways to use this.
    To start a full shortfin server, use the context manager or the fastapi_lifespan method.
    To initialize a shortfin server but not start it, use the constructor, then manipulate the services and sysman attributes directly.
    """

    def __init__(self, args):
        # Load server configuration with priority: command line > config file > defaults
        server_params = ServerParams.load(
            args.server_config if hasattr(args, "server_config") else None
        )
        server_params.update_from_args(args)

        # Setup system (configure devices, etc).
        sysman = SystemManager(
            device=args.device,
            device_ids=server_params.device_ids,
            async_allocs=server_params.amdgpu_async_allocations,
            amdgpu_allocators=server_params.amdgpu_allocators,
        )

        # Setup each service we are hosting.
        eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json)
        tokenizer = Tokenizer.from_tokenizer_json_file(
            args.tokenizer_json, eos_token=eos_token
        )
        model_params = ModelParams.load_json(args.model_config)
        service = GenerateService(
            name="default",
            sysman=sysman,
            tokenizer=tokenizer,
            model_params=model_params,
            server_params=server_params,
            program_isolation=server_params.program_isolation,
        )
        service.load_inference_module(args.vmfb)
        service.load_inference_parameters(*args.parameters, parameter_scope="model")
        self.sysman = sysman
        self.services = {"default": service}

    def __enter__(self):
        self.sysman.start()
        for service_name, service in self.services.items():
            logging.info("Initializing service '%s': %r", service_name, service)
            service.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for service_name, service in self.services.items():
            logging.info("Shutting down service '%s'", service_name)
            service.shutdown()
        self.sysman.shutdown()
        return False

    @asynccontextmanager
    async def fastapi_lifespan(self, app: FastAPI):
        """
        Context manager for FastAPI lifespan events.
        Initializes the system manager and services when the app starts, and shuts them down when the app stops.
        Also provides the services via app.state, which can be accessed from route handlers via
        request.app.state.services.
        Implements the API described in https://fastapi.tiangolo.com/advanced/events/#lifespan
        See `server.py` for a usage example.
        """
        with self:
            app.state.services = self.services
            yield
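
A usage sketch of the two modes the class docstring describes; it assumes `args` carries the fields `__init__` reads (`device`, `tokenizer_json`, `tokenizer_config_json`, `model_config`, `vmfb`, `parameters`, and optionally `server_config`):

```python
from shortfin_apps.llm.components.lifecycle import ShortfinLlmLifecycleManager

# Mode 1: full lifecycle. __enter__ starts the system manager and every
# service; __exit__ shuts them down even if the body raises.
with ShortfinLlmLifecycleManager(args) as manager:
    service = manager.services["default"]
    # ... serve requests ...

# Mode 2: initialize only, then drive the sysman/services attributes directly.
manager = ShortfinLlmLifecycleManager(args)
manager.sysman.start()
manager.services["default"].start()
```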
37 changes: 0 additions & 37 deletions shortfin/python/shortfin_apps/llm/lifecycle_hooks.py

This file was deleted.

6 changes: 4 additions & 2 deletions shortfin/python/shortfin_apps/llm/routes/generate.py
@@ -10,15 +10,17 @@
 
 from ..components.generate import ClientGenerateBatchProcess
 from ..components.io_struct import GenerateReqInput
-from ..lifecycle_hooks import services
+from ..components.service import GenerateService
 
 generation_router = APIRouter()
 
 
 @generation_router.post("/generate")
 @generation_router.put("/generate")
 async def generate_request(gen_req: GenerateReqInput, request: Request):
-    service = services["default"]
+    # app.state.services is populated by the ShortfinLlmLifecycleManager
+    # see shortfin/python/shortfin_apps/llm/components/lifecycle.py
+    service: GenerateService = request.app.state.services["default"]
     gen_req.post_init()
     responder = FastAPIResponder(request)
     ClientGenerateBatchProcess(service, gen_req, responder).launch()
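
Because fastapi_lifespan publishes the services dict on app.state before yielding, any other handler can reach a running service the same way. A hypothetical extra route, for illustration only (not part of this PR):

```python
from fastapi import APIRouter, Request

status_router = APIRouter()  # hypothetical router, not in the shortfin codebase


@status_router.get("/health")
async def health(request: Request):
    # Same lookup the /generate handler above performs.
    services = request.app.state.services
    return {name: "up" for name in services}
```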
55 changes: 4 additions & 51 deletions shortfin/python/shortfin_apps/llm/server.py
@@ -15,12 +15,8 @@
 from shortfin import ProgramIsolation
 import uvicorn
 
-from . import lifecycle_hooks
 from .application import get_app
-from .components.config_struct import ModelParams, ServerParams
-from .components.manager import SystemManager
-from .components.service import GenerateService
-from .components.tokenizer import Tokenizer
+from .components.lifecycle import ShortfinLlmLifecycleManager
 
 
 logger = logging.getLogger(__name__)
@@ -53,50 +49,6 @@
 }
 
 
-def get_eos_from_tokenizer_config(json_path):
-    import json
-
-    with open(json_path, "rt") as f:
-        json_text = f.read()
-    config = json.loads(json_text)
-    return config["eos_token"]
-
-
-def configure(args) -> SystemManager:
-    # Load server configuration with priority: command line > config file > defaults
-    server_params = ServerParams.load(
-        args.server_config if hasattr(args, "server_config") else None
-    )
-    server_params.update_from_args(args)
-
-    # Setup system (configure devices, etc).
-    sysman = SystemManager(
-        device=args.device,
-        device_ids=server_params.device_ids,
-        async_allocs=server_params.amdgpu_async_allocations,
-        amdgpu_allocators=server_params.amdgpu_allocators,
-    )
-
-    # Setup each service we are hosting.
-    eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json)
-    tokenizer = Tokenizer.from_tokenizer_json_file(
-        args.tokenizer_json, eos_token=eos_token
-    )
-    model_params = ModelParams.load_json(args.model_config)
-    sm = GenerateService(
-        name="default",
-        sysman=sysman,
-        tokenizer=tokenizer,
-        model_params=model_params,
-        server_params=server_params,
-        program_isolation=server_params.program_isolation,
-    )
-    sm.load_inference_module(args.vmfb)
-    sm.load_inference_parameters(*args.parameters, parameter_scope="model")
-    lifecycle_hooks.services[sm.name] = sm
-    return sysman
-
-
 def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -194,10 +146,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         args.tokenizer_json.stem + "_config.json"
     )
     args.tokenizer_config_json = inferred_tokenizer_config_path
-    lifecycle_hooks.sysman = configure(args)
 
+    lifecycle_manager = ShortfinLlmLifecycleManager(args)
+
     uvicorn.run(
-        get_app(),
+        get_app(lifecycle_manager.fastapi_lifespan),
         host=args.host,
         port=args.port,
         log_config=log_config,
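
One way to exercise the new wiring end to end: Starlette's TestClient runs the app's lifespan when used as a context manager, so the manager's services start on entry and shut down on exit. A sketch, assuming `args` is prepared as in main() above; the request payload is a placeholder (the real fields live on GenerateReqInput in io_struct.py):

```python
from fastapi.testclient import TestClient

from shortfin_apps.llm.application import get_app
from shortfin_apps.llm.components.lifecycle import ShortfinLlmLifecycleManager

manager = ShortfinLlmLifecycleManager(args)
app = get_app(manager.fastapi_lifespan)

# Entering the client runs startup (services come up); exiting runs shutdown.
with TestClient(app) as client:
    response = client.post("/generate", json={"text": "Hello"})  # placeholder payload
```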