Add argument to disable FastAPI docs (#2540)

mouweng authored Oct 8, 2024
1 parent 2e49fc3 commit 6f34738

Showing 3 changed files with 33 additions and 11 deletions.
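
The change threads one switch through three layers: a `--disable-fastapi-docs` CLI flag, the ArgumentHelper that defines it, and a `disable_fastapi_docs` parameter on `serve()` that decides how the FastAPI app is built. As a hedged sketch of the resulting Python-side usage (the model path and port below are placeholders, not part of this commit):

    # Sketch only: model path and port are placeholder values.
    from lmdeploy.serve.openai.api_server import serve

    serve('internlm/internlm2-chat-7b',   # placeholder model path
          server_port=23333,              # placeholder port
          disable_fastapi_docs=True)      # no Swagger UI, ReDoc, or /openapi.json

On the command line, the equivalent should be `lmdeploy serve api_server <model> --disable-fastapi-docs`.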
4 changes: 3 additions & 1 deletion lmdeploy/cli/serve.py
@@ -145,6 +145,7 @@ def add_parser_api_server():
         ArgumentHelper.ssl(parser)
         ArgumentHelper.model_name(parser)
         ArgumentHelper.max_log_len(parser)
+        ArgumentHelper.disable_fastapi_docs(parser)

         # chat template args
         ArgumentHelper.chat_template(parser)
@@ -342,7 +343,8 @@ def api_server(args):
             api_keys=args.api_keys,
             ssl=args.ssl,
             proxy_url=args.proxy_url,
-            max_log_len=args.max_log_len)
+            max_log_len=args.max_log_len,
+            disable_fastapi_docs=args.disable_fastapi_docs)

     @staticmethod
     def api_client(args):
8 changes: 8 additions & 0 deletions lmdeploy/cli/utils.py
@@ -487,3 +487,11 @@ def max_log_len(parser):
             default=None,
             help='Max number of prompt characters or prompt tokens being'
             'printed in log. Default: Unlimited')
+
+    @staticmethod
+    def disable_fastapi_docs(parser):
+        return parser.add_argument('--disable-fastapi-docs',
+                                   action='store_true',
+                                   default=False,
+                                   help="Disable FastAPI's OpenAPI schema,"
+                                   ' Swagger UI, and ReDoc endpoint')
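
The helper is a plain argparse boolean flag: omitted, the value stays False; passed, it flips to True, with dashes mapped to underscores on the parsed namespace. A self-contained sketch of just that behavior:

    import argparse

    # Rebuild the flag exactly as the helper above defines it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--disable-fastapi-docs',
                        action='store_true',
                        default=False,
                        help="Disable FastAPI's OpenAPI schema,"
                        ' Swagger UI, and ReDoc endpoint')

    assert parser.parse_args([]).disable_fastapi_docs is False
    assert parser.parse_args(['--disable-fastapi-docs']).disable_fastapi_docs is True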
32 changes: 22 additions & 10 deletions lmdeploy/serve/openai/api_server.py
@@ -8,7 +8,7 @@
 from typing import AsyncGenerator, Dict, List, Literal, Optional, Union

 import uvicorn
-from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
@@ -45,7 +45,7 @@ class VariableInterface:
     api_server_url: Optional[str] = None


-app = FastAPI(docs_url='/')
+router = APIRouter()
 get_bearer_token = HTTPBearer(auto_error=False)

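This is the pivotal refactor in the file: the module used to build a FastAPI app at import time (app = FastAPI(docs_url='/')), which fixed the docs URLs before serve() could see the new flag. Registering every endpoint on an APIRouter instead defers app construction to serve(). A minimal sketch of the pattern, with a toy route standing in for the real /v1/* endpoints:

    from fastapi import APIRouter, FastAPI

    router = APIRouter()

    @router.get('/ping')          # toy route, not part of this commit
    async def ping():
        return {'pong': True}

    # Later, e.g. inside serve(), build the app with whatever docs settings
    # were requested and attach everything registered on the router.
    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
    app.include_router(router)
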
@@ -88,7 +88,7 @@ def get_model_list():
     return model_names


-@app.get('/v1/models', dependencies=[Depends(check_api_key)])
+@router.get('/v1/models', dependencies=[Depends(check_api_key)])
 def available_models():
     """Show available models."""
     model_cards = []
@@ -237,7 +237,7 @@ def _create_chat_completion_logprobs(tokenizer: Tokenizer,
     return ChoiceLogprobs(content=content)


-@app.get('/health')
+@router.get('/health')
 async def health() -> Response:
     """Health check."""
     return Response(status_code=200)
@@ -277,7 +277,7 @@ def _logit_bias_processor(
     return partial(_logit_bias_processor, clamped_logit_bias)


-@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
+@router.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
 async def chat_completions_v1(request: ChatCompletionRequest,
                               raw_request: Request = None):
     """Completion API similar to OpenAI's API.
@@ -543,7 +543,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
     return response


-@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
+@router.post('/v1/completions', dependencies=[Depends(check_api_key)])
 async def completions_v1(request: CompletionRequest,
                          raw_request: Request = None):
     """Completion API similar to OpenAI's API.
@@ -751,15 +751,15 @@ async def _inner_call(i, generator):
     return response


-@app.post('/v1/embeddings', tags=['unsupported'])
+@router.post('/v1/embeddings', tags=['unsupported'])
 async def create_embeddings(request: EmbeddingsRequest,
                             raw_request: Request = None):
     """Creates embeddings for the text."""
     return create_error_response(HTTPStatus.BAD_REQUEST,
                                  'Unsupported by turbomind.')


-@app.post('/v1/encode', dependencies=[Depends(check_api_key)])
+@router.post('/v1/encode', dependencies=[Depends(check_api_key)])
 async def encode(request: EncodeRequest, raw_request: Request = None):
     """Encode prompts.
@@ -790,7 +790,7 @@ def encode(prompt: str, do_preprocess: bool, add_bos: bool):
     return EncodeResponse(input_ids=encoded, length=length)


-@app.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
+@router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
 async def chat_interactive_v1(request: GenerateRequest,
                               raw_request: Request = None):
     """Generate completion for the request.
@@ -929,7 +929,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return JSONResponse(ret)


-@app.on_event('startup')
+@router.on_event('startup')
 async def startup_event():
     if VariableInterface.proxy_url is None:
         return
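
Moving the startup hook from app to router still works because include_router() copies a router's event handlers onto the app, so this proxy registration fires on app startup as before. A small sketch (on_event is deprecated in recent FastAPI in favor of lifespan handlers, but remains functional):

    from fastapi import APIRouter, FastAPI

    router = APIRouter()

    @router.on_event('startup')
    async def announce():
        print('runs once the assembled app starts')

    app = FastAPI()
    app.include_router(router)  # transfers the startup handler to the app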
@@ -973,6 +973,7 @@ def serve(model_path: str,
           ssl: bool = False,
           proxy_url: Optional[str] = None,
           max_log_len: int = None,
+          disable_fastapi_docs: bool = False,
           **kwargs):
     """An example to perform model inference through the command line
     interface.
@@ -1022,6 +1023,17 @@
     os.environ['TM_LOG_LEVEL'] = log_level
     logger.setLevel(log_level)

+    if disable_fastapi_docs:
+        app = FastAPI(
+            docs_url=None,
+            redoc_url=None,
+            openapi_url=None,
+        )
+    else:
+        app = FastAPI(docs_url='/')
+
+    app.include_router(router)
+
     if allow_origins:
         app.add_middleware(
             CORSMiddleware,
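
With the flag set, the documentation routes are simply never registered, so requests to them get plain 404s. A quick way to confirm, assuming fastapi's TestClient (which requires httpx) is available:

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
    client = TestClient(app)

    for path in ('/openapi.json', '/docs', '/redoc'):
        assert client.get(path).status_code == 404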
