OAI: Add return types for docs
Adding return types allows responses to be included in the
autogenerated docs.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Jul 8, 2024
1 parent 62e495f commit 521d21b
Showing 1 changed file with 38 additions and 20 deletions.
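For context, FastAPI derives the response schema shown in its autogenerated docs (/docs and /openapi.json) from each endpoint's return type annotation, just as it derives request schemas from parameter annotations. A minimal sketch of that mechanism, using a hypothetical EchoResponse model rather than this repository's types:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class EchoResponse(BaseModel):
    # Hypothetical model; stands in for e.g. ChatCompletionResponse.
    text: str


@app.post("/v1/echo")
async def echo(text: str) -> EchoResponse:
    # The "-> EchoResponse" annotation doubles as the response_model,
    # so the generated docs describe this endpoint's 200 response
    # with EchoResponse's JSON schema.
    return EchoResponse(text=text)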
endpoints/OAI/router.py (58 changes: 38 additions & 20 deletions)
@@ -12,8 +12,11 @@
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import coalesce, unwrap
 from endpoints.OAI.types.auth import AuthPermissionResponse
-from endpoints.OAI.types.completion import CompletionRequest
-from endpoints.OAI.types.chat_completion import ChatCompletionRequest
+from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.chat_completion import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+)
 from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
 from endpoints.OAI.types.lora import (
     LoraCard,
@@ -23,8 +26,10 @@
 )
 from endpoints.OAI.types.model import (
     ModelCard,
+    ModelList,
     ModelLoadRequest,
     ModelCardParameters,
+    ModelLoadResponse,
 )
 from endpoints.OAI.types.sampler_overrides import (
     SamplerOverrideListResponse,
@@ -70,7 +75,7 @@ async def check_model_container():
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
-async def list_models():
+async def list_models() -> ModelList:
     """Lists all models in the model directory."""
     model_config = config.model_config()
     model_dir = unwrap(model_config.get("model_dir"), "models")
@@ -90,7 +95,7 @@ async def list_models():
     "/v1/model",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def get_current_model():
+async def get_current_model() -> ModelCard:
     """Returns the currently loaded model."""
     model_params = model.container.get_model_parameters()
     draft_model_params = model_params.pop("draft", {})
@@ -121,7 +126,7 @@ async def get_current_model():
 
 
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
-async def list_draft_models():
+async def list_draft_models() -> ModelList:
     """Lists all draft models in the model directory."""
     draft_model_dir = unwrap(
         config.draft_model_config().get("draft_model_dir"), "models"
@@ -135,8 +140,8 @@
 
 # Load model endpoint
 @router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(data: ModelLoadRequest):
-    """Loads a model into the model container."""
+async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
+    """Loads a model into the model container. This returns an SSE stream."""
 
     # Verify request parameters
     if not data.name:
@@ -189,7 +194,7 @@ async def unload_model():
 
 @router.get("/v1/templates", dependencies=[Depends(check_api_key)])
 @router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def get_templates():
+async def get_templates() -> TemplateList:
     templates = get_all_templates()
     template_strings = [template.stem for template in templates]
     return TemplateList(data=template_strings)
@@ -233,7 +238,7 @@ async def unload_template():
 # Sampler override endpoints
 @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
 @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides():
+async def list_sampler_overrides() -> SamplerOverrideListResponse:
     """API wrapper to list all currently applied sampler overrides"""
 
     return SamplerOverrideListResponse(
@@ -281,7 +286,7 @@ async def unload_sampler_override():
 
 
 @router.post("/v1/download", dependencies=[Depends(check_admin_key)])
-async def download_model(request: Request, data: DownloadRequest):
+async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
     """Downloads a model from HuggingFace."""
 
     try:
@@ -304,7 +309,7 @@ async def download_model(request: Request, data: DownloadRequest):
 # Lora list endpoint
 @router.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
-async def get_all_loras():
+async def get_all_loras() -> LoraList:
     """Lists all LoRAs in the lora directory."""
     lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
     loras = get_lora_list(lora_path.resolve())
@@ -317,7 +322,7 @@ async def get_all_loras():
     "/v1/lora",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def get_active_loras():
+async def get_active_loras() -> LoraList:
     """Returns the currently loaded loras."""
     active_loras = LoraList(
         data=[
@@ -337,7 +342,7 @@ async def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
-async def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
     """Loads a LoRA into the model container."""
 
     if not data.loras:
@@ -383,7 +388,7 @@ async def unload_loras():
     "/v1/token/encode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def encode_tokens(data: TokenEncodeRequest):
+async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
 
     if isinstance(data.text, str):
@@ -413,7 +418,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     "/v1/token/decode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def decode_tokens(data: TokenDecodeRequest):
+async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
     """Decodes tokens into a string."""
     message = model.container.decode_tokens(data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=unwrap(message, ""))
@@ -426,7 +431,7 @@ async def get_key_permission(
     x_admin_key: Optional[str] = Header(None),
     x_api_key: Optional[str] = Header(None),
     authorization: Optional[str] = Header(None),
-):
+) -> AuthPermissionResponse:
     """
     Gets the access level/permission of a provided key in headers.
 
@@ -452,8 +457,15 @@
     "/v1/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def completion_request(request: Request, data: CompletionRequest):
-    """Generates a completion from a prompt."""
+async def completion_request(
+    request: Request, data: CompletionRequest
+) -> CompletionResponse:
+    """
+    Generates a completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
     model_path = model.container.get_model_path()
 
     if isinstance(data.prompt, list):
Expand Down Expand Up @@ -488,8 +500,14 @@ async def completion_request(request: Request, data: CompletionRequest):
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def chat_completion_request(request: Request, data: ChatCompletionRequest):
"""Generates a chat completion from a prompt."""
async def chat_completion_request(
request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
"""
Generates a chat completion from a prompt.
If stream = true, this returns an SSE stream.
"""

if model.container.prompt_template is None:
error_message = handle_request_error(
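A note on the "this returns an SSE stream" docstrings above: the return annotations document the non-streaming response shape only. FastAPI returns Response instances (including StreamingResponse) directly, skipping response-model serialization, so a streaming branch still works under these annotations. A minimal sketch of that pattern, with invented names rather than this repository's actual handler:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class CompletionResponse(BaseModel):
    # Stand-in for the real CompletionResponse type.
    text: str


async def fake_token_stream():
    # Hypothetical generator yielding SSE-formatted chunks.
    for token in ("Hello", " ", "world"):
        yield f"data: {token}\n\n"


@app.post("/v1/completions")
async def completion_request(stream: bool = False) -> CompletionResponse:
    if stream:
        # A Response instance is passed through as-is; the annotation
        # only shapes the documented (non-streaming) schema.
        return StreamingResponse(fake_token_stream(), media_type="text/event-stream")
    return CompletionResponse(text="Hello world")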