Skip to content

Commit

Permalink
OAI: Amend comments
Browse files Browse the repository at this point in the history
Clarify what the user can and can't see.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Jul 11, 2024
1 parent 1f46a11 commit 9fc3fc4
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions endpoints/OAI/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ async def chat_completion_request(
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(request: Request) -> ModelList:
"""Lists all models in the model directory."""
"""
Lists all models in the model directory.
Requires an admin key to see all models.
"""

model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
Expand Down Expand Up @@ -207,7 +212,11 @@ async def current_model() -> ModelCard:

@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(request: Request) -> ModelList:
"""Lists all draft models in the model directory."""
"""
Lists all draft models in the model directory.
Requires an admin key to see all draft models.
"""

if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
Expand Down Expand Up @@ -301,7 +310,11 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras(request: Request) -> LoraList:
"""Lists all LoRAs in the lora directory."""
"""
Lists all LoRAs in the lora directory.
Requires an admin key to see all LoRAs.
"""

if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
Expand Down Expand Up @@ -406,6 +419,7 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
"""Decodes tokens into a string."""

message = model.container.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))

Expand Down Expand Up @@ -435,7 +449,11 @@ async def key_permission(request: Request) -> AuthPermissionResponse:
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates(request: Request) -> TemplateList:
"""Get a list of all templates."""
"""
Get a list of all templates.
Requires an admin key to see all templates.
"""

template_strings = []
if get_key_permission(request) == "admin":
Expand All @@ -453,7 +471,7 @@ async def list_templates(request: Request) -> TemplateList:
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
"""Switch the currently loaded template."""

if not data.name:
error_message = handle_request_error(
Expand Down Expand Up @@ -488,7 +506,11 @@ async def unload_template():
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
"""API wrapper to list all currently applied sampler overrides"""
"""
List all currently applied sampler overrides.
Requires an admin key to see all override presets.
"""

if get_key_permission(request) == "admin":
presets = sampling.get_all_presets()
Expand Down

0 comments on commit 9fc3fc4

Please sign in to comment.