Skip to content

Commit

Permalink
rewrite API keys implementation
Browse files Browse the repository at this point in the history
- add RBAC support
- support multiple API keys
- keep backwards compatibility
- easy extension to other authentication types
  • Loading branch information
SecretiveShell committed Sep 6, 2024
1 parent d34756d commit e0bc35f
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 192 deletions.
319 changes: 181 additions & 138 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,159 +6,202 @@
import secrets
import yaml
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from pydantic import BaseModel, Field
from loguru import logger
from typing import Optional

from common.utils import coalesce


class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
"""

api_key: str
admin_key: str

def verify_key(self, test_key: str, key_type: str):
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
if key_type == "api_key":
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
return False
from typing import Optional, Union
from enum import Flag, auto
from abc import ABC, abstractmethod

from common.utils import coalesce, unwrap

# Global auth constants
AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
__all__ = ["ROLE", "auth"]


def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
global DISABLE_AUTH
# RBAC roles
class ROLE(Flag):
USER = auto()
ADMIN = auto()

DISABLE_AUTH = disable_from_config
if disable_from_config:
logger.warning(
"Disabling authentication makes your instance vulnerable. "
"Set the `disable_auth` flag to False in config.yml if you "
"want to share this instance with others."
)

return
class API_KEY(BaseModel):
"""stores an API key"""

key: str = Field(..., description="the API key value")
role: ROLE = Field()


class AUTH_PROVIDER(ABC):
@staticmethod
def add_api_key(role: ROLE) -> API_KEY:
"""add an API key"""

@staticmethod
def set_api_key(role: ROLE, api_key: str) -> API_KEY:
"""add an existing API key"""

@staticmethod
def remove_api_key(api_key: str) -> bool:
"""remove an API key"""

@staticmethod
def check_api_key(api_key: str) -> Union[API_KEY, None]:
"""check if an API key is valid"""

@staticmethod
def authenticate_api_key(api_key: str, role: ROLE) -> bool:
"""check if an api key has ROLE"""


try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
class SIMPLE_AUTH_PROVIDER(AUTH_PROVIDER):
api_keys: list[API_KEY] = []

def __init__(self) -> None:
try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
keys_dict: dict = yaml.safe_load(auth_file)

# load legacy keys
admin_key = keys_dict.get("admin_key")
if admin_key:
self.set_api_key(ROLE.ADMIN, admin_key)

admin_key = keys_dict.get("api_key")
if admin_key:
self.set_api_key(ROLE.USER, admin_key)

# load new keys
admin_keys = keys_dict.get("admin_keys")
if admin_keys:
for key in admin_keys:
self.set_api_key(ROLE.ADMIN, key)

user_keys = keys_dict.get("user_keys")
if admin_keys:
for key in admin_keys:
self.set_api_key(ROLE.ADMIN, key)

except FileNotFoundError:
file = {
"admin_keys": [
self.add_api_key(ROLE.ADMIN),
],
"user_keys": [
self.add_api_key(ROLE.USER),
],
}

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(file, auth_file, default_flow_style=False)

logger.info("API keys:")
for key in self.api_keys:
logger.info(f"{key.role.name} :\t {key.key}")
logger.info(
"If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!"
)
AUTH_KEYS = new_auth_keys

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
)


def get_key_permission(request: Request):
"""
Gets the key permission from a request.
Internal only! Use the depends functions for incoming requests.
"""

# Give full admin permissions if auth is disabled
if DISABLE_AUTH:
return "admin"

# Hyphens are okay here
test_key = coalesce(
request.headers.get("x-admin-key"),
request.headers.get("x-api-key"),
request.headers.get("authorization"),
)

if test_key is None:
raise ValueError("The provided authentication key is missing.")

if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]

if AUTH_KEYS.verify_key(test_key, "admin_key"):
return "admin"
elif AUTH_KEYS.verify_key(test_key, "api_key"):
return "api"
else:
raise ValueError("The provided authentication key is invalid.")


async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""

# Allow request if auth is disabled
if DISABLE_AUTH:
return

if x_api_key:
if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
raise HTTPException(401, "Invalid API key")
return x_api_key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid API key")
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "api_key"
):
raise HTTPException(401, "Invalid API key")

return authorization
def add_api_key(self, role: ROLE) -> API_KEY:
return self.set_api_key(key=secrets.token_hex(16), role=role)

raise HTTPException(401, "Please provide an API key")
def set_api_key(self, role: ROLE, api_key: str) -> API_KEY:
key = API_KEY(key=api_key, role=role)
self.api_keys.append(key)
return key

def remove_api_key(self, api_key: str) -> bool:
for key in self.api_keys:
if key.key == api_key:
self.api_keys.remove(key)
return True
return False

async def check_admin_key(
x_admin_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the admin key is valid."""
def check_api_key(self, api_key: str) -> Union[API_KEY, None]:
for key in self.api_keys:
if key.key == api_key:
return key
return None

# Allow request if auth is disabled
if DISABLE_AUTH:
return
def authenticate_api_key(self, api_key: str, role: ROLE) -> bool:
key = self.check_api_key(api_key)
print(f"#### {key=}")
if not key:
return False
return key.role & role # if key.role in role

if x_admin_key:
if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
raise HTTPException(401, "Invalid admin key")
return x_admin_key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid admin key")
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "admin_key"
class NOAUTH_AUTH_PROVIDER(AUTH_PROVIDER):
def add_api_key(self, role: ROLE) -> API_KEY:
return API_KEY(key=secrets.token_hex(16), role=role)

def set_api_key(self, role: ROLE, api_key: str) -> API_KEY:
return API_KEY(key=secrets.token_hex(16), role=role)

def remove_api_key(self, api_key: str) -> bool:
return True

def check_api_key(self, api_key: str) -> Union[API_KEY, None]:
return API_KEY(key=secrets.token_hex(16), role=ROLE.ADMIN)

def authenticate_api_key(self, api_key: str, role: ROLE) -> bool:
return True


class AUTH_PROVIDER_CONTAINER:
provider: AUTH_PROVIDER

def load(self, disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""

# TODO: Make provider a paramater instead of disable_from_config
provider = "noauth" if disable_from_config else "simple"

# allows for more types of providers
provider_class = {
"noauth": NOAUTH_AUTH_PROVIDER,
"simple": SIMPLE_AUTH_PROVIDER,
}.get(provider)

if not provider_class:
raise Exception()

if provider_class == NOAUTH_AUTH_PROVIDER:
logger.warning(
"Disabling authentication makes your instance vulnerable. "
"Set the `disable_auth` flag to False in config.yml if you "
"want to share this instance with others."
)

self.provider = provider_class()

# by returning a dynamic dependency we can have one function where we can specify what roles can access the endpoint
def check_api_key(self, role: ROLE):
"""Check if the API key is valid."""

async def check(
x_api_key: str = Header(None), authorization: str = Header(None)
):
raise HTTPException(401, "Invalid admin key")
return authorization
if x_api_key:
if not self.provider.authenticate_api_key(x_api_key, role):
raise HTTPException(401, "Invalid API key")
return x_api_key

if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid API key")
if split_key[
0
].lower() != "bearer" or not self.provider.authenticate_api_key(
split_key[1], role
):
raise HTTPException(401, "Invalid API key")

return authorization

raise HTTPException(401, "Please provide an API key")

return check


raise HTTPException(401, "Please provide an admin key")
auth = AUTH_PROVIDER_CONTAINER()
Loading

0 comments on commit e0bc35f

Please sign in to comment.