diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 6cfccf5c..ff0ec8c5 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -29,9 +29,20 @@ "Streaming": "http://{host}:{port}/api/extra/generate/stream", } +kai_router = APIRouter() +extra_kai_router = APIRouter() -@router.post( - "/v1/generate", + +def setup(): + router.include_router(kai_router, prefix="/v1") + router.include_router(kai_router, prefix="/latest", include_in_schema=False) + router.include_router(extra_kai_router, prefix="/extra") + + return router + + +@kai_router.post( + "/generate", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: @@ -40,8 +51,8 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: return response -@router.post( - "/extra/generate/stream", +@extra_kai_router.post( + "/generate/stream", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: @@ -50,8 +61,8 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe return response -@router.post( - "/extra/abort", +@extra_kai_router.post( + "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def abort_generate(data: AbortRequest): @@ -60,12 +71,12 @@ async def abort_generate(data: AbortRequest): return response -@router.get( - "/extra/generate/check", +@extra_kai_router.get( + "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) -@router.post( - "/extra/generate/check", +@extra_kai_router.post( + "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @@ -74,8 +85,8 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: return response -@router.get( - "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] +@kai_router.get( + "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] ) async def current_model(): """Fetches the current model and who owns it.""" @@ -84,8 +95,8 @@ async def current_model(): return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"} -@router.post( - "/extra/tokencount", +@extra_kai_router.post( + "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], ) async def get_tokencount(data: TokenCountRequest): @@ -94,14 +105,14 @@ async def get_tokencount(data: TokenCountRequest): return TokenCountResponse(value=len(tokens), ids=tokens) -@router.get("/v1/info/version") +@kai_router.get("/info/version") async def get_version(): """Impersonate KAI United.""" return {"result": "1.2.5"} -@router.get("/extra/version") +@extra_kai_router.get("/version") async def get_extra_version(): """Impersonate Koboldcpp.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 55a2205f..d9701619 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -32,6 +32,10 @@ } +def setup(): + return router + + # Completions endpoint @router.post( "/v1/completions", diff --git a/endpoints/server.py b/endpoints/server.py index f8bfc8ac..0b3edfbb 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -47,14 +47,14 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None): selected_server = router_mapping.get(server.lower()) if selected_server: - app.include_router(selected_server.router) + app.include_router(selected_server.setup()) logger.info(f"Starting {selected_server.api_name} API") for path, url in selected_server.urls.items(): formatted_url = url.format(host=host, port=port) logger.info(f"{path}: {formatted_url}") else: - app.include_router(OAIRouter.router) + app.include_router(OAIRouter.setup()) for path, url in OAIRouter.urls.items(): formatted_url = url.format(host=host, port=port) logger.info(f"{path}: {formatted_url}")