Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multiple API keys, generating OAI-like tokens by default #99

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,35 @@
application, it should be fine.
"""

import string
import secrets
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
from loguru import logger
from typing import Optional
from typing import Optional, List


class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
It contains two types of keys: 'api_key'/'api_keys' and 'admin_key'.
The 'api_key'/'api_keys' are used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
to verify if a given key matches the stored 'api_key'/'api_keys' or 'admin_key'.
"""

api_key: str
api_keys: Optional[List[str]] = None
admin_key: str

def verify_key(self, test_key: str, key_type: str):
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
if key_type == "api_key":
if isinstance(self.api_keys, list) and test_key in self.api_keys:
return True
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
return False
Expand All @@ -38,6 +42,20 @@ def verify_key(self, test_key: str, key_type: str):
DISABLE_AUTH: bool = False


def gen_rand_ascii(length):
chars = string.ascii_letters + string.digits
nchars = len(chars)
secure_token = secrets.token_bytes(length)
return "".join(map(lambda b: chars[(b % nchars)], secure_token))


# some apps check this regexp https://github.com/secretlint/secretlint/issues/676
def gen_oai_like_key():
prefix = "sk-"
suffix = "T3BlbkFJ"
return f"{prefix}{gen_rand_ascii(20)}{suffix}{gen_rand_ascii(20)}"


def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
Expand All @@ -60,15 +78,21 @@ def load_auth_keys(disable_from_config: bool):
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
api_keys=[gen_oai_like_key()],
api_key=gen_oai_like_key(),
admin_key=secrets.token_hex(16),
)
AUTH_KEYS = new_auth_keys

with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

multiple_keys_msg = ""
if isinstance(AUTH_KEYS.api_keys, list):
multiple_keys_msg = f"Your additional API keys are: {'\n'.join(AUTH_KEYS.api_keys)}\n"

logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your API key is: {AUTH_KEYS.api_key}\n{multiple_keys_msg}"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
Expand Down