From 93203ef46ab8fef35a2c0a7d9e469150f337bc41 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:37:38 +0100 Subject: [PATCH] Add tags to API endpoints --- endpoints/Kobold/router.py | 24 +++++++++--- endpoints/OAI/router.py | 8 ++-- endpoints/core/router.py | 74 ++++++++++++++++++++++++++++-------- endpoints/core/types/tags.py | 13 +++++++ 4 files changed, 94 insertions(+), 25 deletions(-) create mode 100644 endpoints/core/types/tags.py diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 310a3809..c8e45bb4 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -6,6 +6,7 @@ from common.auth import check_api_key from common.model import check_model_container from common.utils import unwrap +from endpoints.core.types.tags import Tags from endpoints.core.utils.model import get_current_model from endpoints.Kobold.types.generation import ( AbortRequest, @@ -46,6 +47,7 @@ def setup(): @kai_router.post( "/generate", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: response = await get_generation(data, request) @@ -56,6 +58,7 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: @extra_kai_router.post( "/generate/stream", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: response = EventSourceResponse(stream_generation(data, request), ping=maxsize) @@ -66,6 +69,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe @extra_kai_router.post( "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def abort_generate(data: AbortRequest) -> AbortResponse: response = await abort_generation(data.genkey) @@ -76,10 +80,12 @@ async def abort_generate(data: AbortRequest) -> AbortResponse: @extra_kai_router.get( "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) @extra_kai_router.post( "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: response = await generation_status(data.genkey) @@ -88,7 +94,9 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @kai_router.get( - "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] + "/model", + dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def current_model() -> CurrentModelResponse: """Fetches the current model and who owns it.""" @@ -100,6 +108,7 @@ async def current_model() -> CurrentModelResponse: @extra_kai_router.post( "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: raw_tokens = model.container.encode_tokens(data.prompt) @@ -110,14 +119,17 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: @kai_router.get( "/config/max_length", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) @kai_router.get( "/config/max_context_length", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) @extra_kai_router.get( "/true_max_context_length", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def get_max_length() -> MaxLengthResponse: """Fetches the max length of the model.""" @@ -126,35 +138,35 @@ async def get_max_length() -> MaxLengthResponse: return {"value": max_length} -@kai_router.get("/info/version") +@kai_router.get("/info/version", tags=[Tags.Kobold]) async def get_version(): """Impersonate KAI United.""" return {"result": "1.2.5"} -@extra_kai_router.get("/version") +@extra_kai_router.get("/version", tags=[Tags.Kobold]) async def get_extra_version(): """Impersonate Koboldcpp.""" return {"result": "KoboldCpp", "version": "1.71"} -@kai_router.get("/config/soft_prompts_list") +@kai_router.get("/config/soft_prompts_list", tags=[Tags.Kobold]) async def get_available_softprompts(): """Used for KAI compliance.""" return {"values": []} -@kai_router.get("/config/soft_prompt") +@kai_router.get("/config/soft_prompt", tags=[Tags.Kobold]) async def get_current_softprompt(): """Used for KAI compliance.""" return {"value": ""} -@kai_router.put("/config/soft_prompt") +@kai_router.put("/config/soft_prompt", tags=[Tags.Kobold]) async def set_current_softprompt(): """Used for KAI compliance.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b6a44c98..094aab62 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -25,6 +25,7 @@ stream_generate_completion, ) from endpoints.OAI.utils.embeddings import get_embeddings +from endpoints.core.types.tags import Tags api_name = "OAI" @@ -41,8 +42,7 @@ def setup(): # Completions endpoint @router.post( - "/v1/completions", - dependencies=[Depends(check_api_key)], + "/v1/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI] ) async def completion_request( request: Request, data: CompletionRequest @@ -96,8 +96,7 @@ async def completion_request( # Chat completions endpoint @router.post( - "/v1/chat/completions", - dependencies=[Depends(check_api_key)], + "/v1/chat/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI] ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -156,6 +155,7 @@ async def chat_completion_request( @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + tags=[Tags.OpenAI], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: embeddings_task = asyncio.create_task(get_embeddings(data, request)) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 2c60cd77..deb17661 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -28,6 +28,7 @@ SamplerOverrideListResponse, SamplerOverrideSwitchRequest, ) +from endpoints.core.types.tags import Tags from endpoints.core.types.template import TemplateList, TemplateSwitchRequest from endpoints.core.types.token import ( TokenDecodeRequest, @@ -48,7 +49,7 @@ # Healthcheck endpoint -@router.get("/health") +@router.get("/health", tags=[Tags.Core]) async def healthcheck(response: Response) -> HealthCheckResponse: """Get the current service health status""" healthy, issues = await HealthManager.is_service_healthy() @@ -62,8 +63,12 @@ async def healthcheck(response: Response) -> HealthCheckResponse: # Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/models", + dependencies=[Depends(check_api_key)], + tags=[Tags.OpenAI, Tags.List], +) +@router.get("/v1/model/list", dependencies=[Depends(check_api_key)], tags=[Tags.List]) async def list_models(request: Request) -> ModelList: """ Lists all models in the model directory. @@ -91,6 +96,7 @@ async def list_models(request: Request) -> ModelList: @router.get( "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.List], ) async def current_model() -> ModelCard: """Returns the currently loaded model.""" @@ -98,7 +104,11 @@ async def current_model() -> ModelCard: return get_current_model() -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/draft/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_draft_models(request: Request) -> ModelList: """ Lists all draft models in the model directory. @@ -118,7 +128,9 @@ async def list_draft_models(request: Request) -> ModelList: # Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +@router.post( + "/v1/model/load", dependencies=[Depends(check_admin_key)], tags=[Tags.Admin] +) async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: """Loads a model into the model container. This returns an SSE stream.""" @@ -163,13 +175,14 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: @router.post( "/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_model(): """Unloads the currently loaded model.""" await model.unload_model(skip_wait=True) -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +@router.post("/v1/download", dependencies=[Depends(check_admin_key)], tags=[Tags.Admin]) async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: """Downloads a model from HuggingFace.""" @@ -191,8 +204,8 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +@router.get("/v1/loras", dependencies=[Depends(check_api_key)], tags=[Tags.List]) +@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)], tags=[Tags.List]) async def list_all_loras(request: Request) -> LoraList: """ Lists all LoRAs in the lora directory. @@ -213,6 +226,7 @@ async def list_all_loras(request: Request) -> LoraList: @router.get( "/v1/lora", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.List], ) async def active_loras() -> LoraList: """Returns the currently loaded loras.""" @@ -224,6 +238,7 @@ async def active_loras() -> LoraList: @router.post( "/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: """Loads a LoRA into the model container.""" @@ -259,6 +274,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: @router.post( "/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_loras(): """Unloads the currently loaded loras.""" @@ -266,7 +282,11 @@ async def unload_loras(): await model.unload_loras() -@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/embedding/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_embedding_models(request: Request) -> ModelList: """ Lists all embedding models in the model directory. @@ -288,6 +308,7 @@ async def list_embedding_models(request: Request) -> ModelList: @router.get( "/v1/model/embedding", dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + tags=[Tags.List], ) async def get_embedding_model() -> ModelCard: """Returns the currently loaded embedding model.""" @@ -296,7 +317,11 @@ async def get_embedding_model() -> ModelCard: return models.data[0] -@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +@router.post( + "/v1/model/embedding/load", + dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], +) async def load_embedding_model( request: Request, data: EmbeddingModelLoadRequest ) -> ModelLoadResponse: @@ -343,6 +368,7 @@ async def load_embedding_model( @router.post( "/v1/model/embedding/unload", dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], + tags=[Tags.Admin], ) async def unload_embedding_model(): """Unloads the current embedding model.""" @@ -354,6 +380,7 @@ async def unload_embedding_model(): @router.post( "/v1/token/encode", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Tokenisation], ) async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" @@ -384,6 +411,7 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: @router.post( "/v1/token/decode", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Tokenisation], ) async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: """Decodes tokens into a string.""" @@ -394,7 +422,9 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: return response -@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/auth/permission", dependencies=[Depends(check_api_key)], tags=[Tags.Auth] +) async def key_permission(request: Request) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. @@ -414,8 +444,10 @@ async def key_permission(request: Request) -> AuthPermissionResponse: raise HTTPException(400, error_message) from exc -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +@router.get("/v1/templates", dependencies=[Depends(check_api_key)], tags=[Tags.List]) +@router.get( + "/v1/template/list", dependencies=[Depends(check_api_key)], tags=[Tags.List] +) async def list_templates(request: Request) -> TemplateList: """ Get a list of all templates. @@ -437,6 +469,7 @@ async def list_templates(request: Request) -> TemplateList: @router.post( "/v1/template/switch", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template.""" @@ -464,6 +497,7 @@ async def switch_template(data: TemplateSwitchRequest): @router.post( "/v1/template/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_template(): """Unloads the currently selected template""" @@ -472,8 +506,16 @@ async def unload_template(): # Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/sampling/overrides", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) +@router.get( + "/v1/sampling/override/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: """ List all currently applied sampler overrides. @@ -494,6 +536,7 @@ async def list_sampler_overrides(request: Request) -> SamplerOverrideListRespons @router.post( "/v1/sampling/override/switch", dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], ) async def switch_sampler_override(data: SamplerOverrideSwitchRequest): """Switch the currently loaded override preset""" @@ -523,6 +566,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): @router.post( "/v1/sampling/override/unload", dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], ) async def unload_sampler_override(): """Unloads the currently selected override preset""" diff --git a/endpoints/core/types/tags.py b/endpoints/core/types/tags.py new file mode 100644 index 00000000..0d785885 --- /dev/null +++ b/endpoints/core/types/tags.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class Tags(str, Enum): + """openapi endpoint groups""" + + OpenAI = "OpenAI" + Kobold = "Kobold" + Admin = "Admin" + List = "List" + Tokenisation = "Tokenisation" + Core = "Core" + Auth = "Auth"