Skip to content

Commit

Permalink
Endpoints: Add key permission checker
Browse files Browse the repository at this point in the history
This is a definite way to check if an authorized key is API or admin.
The endpoint only runs if the key is valid in the first place to keep
inline with the API's security model.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Mar 18, 2024
1 parent c9a6d9a commit 3c08f46
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 2 deletions.
14 changes: 14 additions & 0 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from loguru import logger
from typing import Optional

from endpoints.OAI.types.auth import AuthPermissionResponse


class AuthKeys(BaseModel):
"""
Expand Down Expand Up @@ -75,6 +77,18 @@ def load_auth_keys(disable_from_config: bool):
)


async def validate_key_permission(test_key: str):
if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]

if AUTH_KEYS.verify_key(test_key, "admin_key"):
return AuthPermissionResponse(permission="admin")
elif AUTH_KEYS.verify_key(test_key, "api_key"):
return AuthPermissionResponse(permission="api")
else:
raise ValueError("The provided authentication key is invalid.")


async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
Expand Down
32 changes: 30 additions & 2 deletions endpoints/OAI/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import signal
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional

from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
Expand All @@ -22,6 +23,7 @@
get_template_from_file,
)
from common.utils import (
coalesce,
handle_request_error,
unwrap,
)
Expand Down Expand Up @@ -399,6 +401,32 @@ async def decode_tokens(data: TokenDecodeRequest):
return response


@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
):
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-api-key
- X-admin-key
- Authorization
"""

test_key = coalesce(x_admin_key, x_api_key, authorization)

try:
response = await validate_key_permission(test_key)
return response
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message

raise HTTPException(400, error_message) from exc


# Completions endpoint
@app.post(
"/v1/completions",
Expand Down
7 changes: 7 additions & 0 deletions endpoints/OAI/types/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Types for auth requests."""

from pydantic import BaseModel


class AuthPermissionResponse(BaseModel):
permission: str

0 comments on commit 3c08f46

Please sign in to comment.