diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py
index c8e45bb4..9d4c5f58 100644
--- a/endpoints/Kobold/router.py
+++ b/endpoints/Kobold/router.py
@@ -26,14 +26,14 @@
 
 api_name = "KoboldAI"
 
-router = APIRouter(prefix="/api")
+router = APIRouter(prefix="/api", tags=[Tags.Kobold])
 urls = {
     "Generation": "http://{host}:{port}/api/v1/generate",
     "Streaming": "http://{host}:{port}/api/extra/generate/stream",
 }
 
-kai_router = APIRouter()
-extra_kai_router = APIRouter()
+kai_router = APIRouter(tags=[Tags.Kobold])
+extra_kai_router = APIRouter(tags=[Tags.Kobold])
 
 
 def setup():
@@ -50,6 +50,7 @@ def setup():
     tags=[Tags.Kobold],
 )
 async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
+    """Generate a response to a prompt."""
     response = await get_generation(data, request)
     return response
 
@@ -61,6 +62,7 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
     tags=[Tags.Kobold],
 )
 async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
+    """Stream the chat response to a prompt."""
     response = EventSourceResponse(stream_generation(data, request), ping=maxsize)
     return response
 
@@ -72,6 +74,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe
     tags=[Tags.Kobold],
 )
 async def abort_generate(data: AbortRequest) -> AbortResponse:
+    """Aborts a generation from the cache."""
     response = await abort_generation(data.genkey)
     return response
 
@@ -88,6 +91,7 @@ async def abort_generate(data: AbortRequest) -> AbortResponse:
     tags=[Tags.Kobold],
 )
 async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
+    """Fetches the status of a generation from the cache."""
     response = await generation_status(data.genkey)
     return response
 
@@ -111,6 +115,7 @@ async def current_model() -> CurrentModelResponse:
     tags=[Tags.Kobold],
 )
 async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
+    """Get the number of tokens in a given prompt."""
     raw_tokens = model.container.encode_tokens(data.prompt)
     tokens = unwrap(raw_tokens, [])
     return TokenCountResponse(value=len(tokens), ids=tokens)
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 094aab62..2143c173 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -29,7 +29,7 @@
 
 api_name = "OAI"
 
-router = APIRouter()
+router = APIRouter(tags=[Tags.OpenAI])
 urls = {
     "Completions": "http://{host}:{port}/v1/completions",
     "Chat completions": "http://{host}:{port}/v1/chat/completions",
@@ -158,6 +158,10 @@ async def chat_completion_request(
     tags=[Tags.OpenAI],
 )
 async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    """Generate text embeddings for a given text input.
+
+    Requires Infinity embed to be installed and an embedding model to be loaded.
+    """
     embeddings_task = asyncio.create_task(get_embeddings(data, request))
     response = await run_with_request_disconnect(
         request,
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index deb17661..715ba74a 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -325,6 +325,8 @@ async def get_embedding_model() -> ModelCard:
 async def load_embedding_model(
     request: Request, data: EmbeddingModelLoadRequest
 ) -> ModelLoadResponse:
+    """Loads an embedding model."""
+
     # Verify request parameters
     if not data.name:
         error_message = handle_request_error(
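
All three routers reference a shared `Tags` enum that the diff itself never defines. A minimal sketch of what such an enum could look like follows; the module location and exact string values are assumptions, since only the members `Tags.Kobold` and `Tags.OpenAI` appear in the patch:

```python
# Hypothetical sketch of the shared Tags enum referenced in the diff.
# Module path and string values are assumptions; only the member names
# Tags.Kobold and Tags.OpenAI are confirmed by the patch.
from enum import Enum


class Tags(str, Enum):
    """Tag names used to group endpoints in the generated OpenAPI docs."""

    Kobold = "Kobold"
    OpenAI = "OpenAI"
```

FastAPI accepts `Enum` members anywhere a tag string is expected, so passing `tags=[Tags.Kobold]` to `APIRouter(...)` applies the tag to every route registered on that router; router-level tags are combined with any per-route `tags=[...]` arguments when the OpenAPI schema is generated.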