From 61f64033c32e1ebc79cbe89c4d03ccb1162cddb2 Mon Sep 17 00:00:00 2001 From: kir-gadjello <111190790+kir-gadjello@users.noreply.github.com> Date: Wed, 1 May 2024 17:13:28 -0300 Subject: [PATCH] Added support for multiple API keys, generating OAI-like tokens by default. Implements #79 --- common/auth.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/common/auth.py b/common/auth.py index fa532622..642f49ef 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,24 +3,26 @@ application, it should be fine. """ +import string import secrets import yaml from fastapi import Header, HTTPException from pydantic import BaseModel from loguru import logger -from typing import Optional +from typing import Optional, List class AuthKeys(BaseModel): """ This class represents the authentication keys for the application. - It contains two types of keys: 'api_key' and 'admin_key'. - The 'api_key' is used for general API calls, while the 'admin_key' + It contains two types of keys: 'api_key'/'api_keys' and 'admin_key'. + The 'api_key'/'api_keys' are used for general API calls, while the 'admin_key' is used for administrative tasks. The class also provides a method - to verify if a given key matches the stored 'api_key' or 'admin_key'. + to verify if a given key matches the stored 'api_key'/'api_keys' or 'admin_key'. """ api_key: str + api_keys: Optional[List[str]] = None admin_key: str def verify_key(self, test_key: str, key_type: str): @@ -28,6 +30,8 @@ def verify_key(self, test_key: str, key_type: str): if key_type == "admin_key": return test_key == self.admin_key if key_type == "api_key": + if isinstance(self.api_keys, list) and test_key in self.api_keys: + return True # Admin keys are valid for all API calls return test_key == self.api_key or test_key == self.admin_key return False @@ -38,6 +42,20 @@ def verify_key(self, test_key: str, key_type: str): DISABLE_AUTH: bool = False +def gen_rand_ascii(length): + chars = string.ascii_letters + string.digits + nchars = len(chars) + secure_token = secrets.token_bytes(length) + return "".join(map(lambda b: chars[(b % nchars)], secure_token)) + + +# some apps check this regexp https://github.com/secretlint/secretlint/issues/676 +def gen_oai_like_key(): + prefix = "sk-" + suffix = "T3BlbkFJ" + return f"{prefix}{gen_rand_ascii(20)}{suffix}{gen_rand_ascii(20)}" + + def load_auth_keys(disable_from_config: bool): """Load the authentication keys from api_tokens.yml. If the file does not exist, generate new keys and save them to api_tokens.yml.""" @@ -60,15 +78,21 @@ def load_auth_keys(disable_from_config: bool): AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( - api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) + api_keys=[gen_oai_like_key()], + api_key=gen_oai_like_key(), + admin_key=secrets.token_hex(16), ) AUTH_KEYS = new_auth_keys with open("api_tokens.yml", "w", encoding="utf8") as auth_file: yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) + multiple_keys_msg = "" + if isinstance(AUTH_KEYS.api_keys, list): + multiple_keys_msg = f"Your additional API keys are: {'\n'.join(AUTH_KEYS.api_keys)}\n" + logger.info( - f"Your API key is: {AUTH_KEYS.api_key}\n" + f"Your API key is: {AUTH_KEYS.api_key}\n{multiple_keys_msg}" f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " "and restart the server. Have fun!"