-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bot: Start implementing new FastAPI-based server
- Loading branch information
1 parent
9504243
commit f6d4cc6
Showing
24 changed files
with
1,132 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import TYPE_CHECKING | ||
|
||
from fastapi import HTTPException, Request | ||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR | ||
|
||
if TYPE_CHECKING: | ||
from bot.database import Database | ||
|
||
|
||
def database_dependency(request: Request) -> 'Database': | ||
""" | ||
FastAPI dependency to get the database object. | ||
Args: | ||
request (web.Request): The request. | ||
Returns: | ||
Database: The database object. | ||
""" | ||
|
||
state = request.app.state | ||
if not hasattr(state, 'database'): | ||
raise HTTPException( | ||
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection' | ||
) | ||
|
||
database: 'Database' = state.database | ||
if database is None: | ||
raise HTTPException( | ||
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection' | ||
) | ||
|
||
return database |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from fastapi import Depends, HTTPException, Request | ||
from starlette.status import HTTP_401_UNAUTHORIZED | ||
|
||
from .database import database_dependency | ||
|
||
if TYPE_CHECKING: | ||
from bot.api.utils.session import SessionManager | ||
from bot.database import Database | ||
from bot.models.oauth import OAuth | ||
|
||
|
||
EXPECTED_AUTH_SCHEME = 'Bearer' | ||
EXPECTED_AUTH_PARTS = 2 | ||
|
||
|
||
def session_dependency( | ||
request: Request, db: 'Database' = Depends(database_dependency) | ||
) -> 'OAuth': | ||
""" | ||
FastAPI dependency to get the requesting user's info. | ||
Args: | ||
request (web.Request): The request. | ||
Returns: | ||
OAuth: The info for the current Discord user. | ||
""" | ||
|
||
authorization = request.headers.get('Authorization') | ||
if authorization is None: | ||
raise HTTPException( | ||
status_code=HTTP_401_UNAUTHORIZED, detail='No authorization header' | ||
) | ||
|
||
parts = authorization.split() | ||
if len(parts) != EXPECTED_AUTH_PARTS: | ||
raise HTTPException( | ||
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization header' | ||
) | ||
|
||
scheme, token = parts | ||
if scheme != EXPECTED_AUTH_SCHEME: | ||
raise HTTPException( | ||
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization scheme' | ||
) | ||
|
||
session_manager: 'SessionManager' = request.app.state.session_manager | ||
session = session_manager.decode_session(token) | ||
if session is None: | ||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='Invalid session') | ||
|
||
user: Optional['OAuth'] = db.get_oauth('discord', session.user_id) | ||
if user is None: | ||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='User not found') | ||
|
||
return user |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
Nextcord extension that runs the API server for the bot | ||
""" | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from .main import run_app | ||
|
||
if TYPE_CHECKING: | ||
from bot.utils.blanco import BlancoBot | ||
|
||
|
||
def setup(bot: 'BlancoBot'): | ||
""" | ||
Run the API server within the bot's existing event loop. | ||
""" | ||
run_app(bot.loop, bot.database) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
""" | ||
Main module for the API server. | ||
""" | ||
|
||
from asyncio import set_event_loop | ||
from contextlib import asynccontextmanager | ||
from logging import INFO | ||
from typing import TYPE_CHECKING, Any, Optional | ||
|
||
from fastapi import FastAPI | ||
from uvicorn import Config, Server, run | ||
from uvicorn.config import LOGGING_CONFIG | ||
|
||
from bot.database import Database | ||
from bot.utils.config import config as bot_config | ||
from bot.utils.logger import DATE_FMT_STR, LOG_FMT_COLOR, create_logger | ||
|
||
from .routes.account import account_router | ||
from .routes.oauth import oauth_router | ||
from .utils.session import SessionManager | ||
|
||
if TYPE_CHECKING: | ||
from asyncio import AbstractEventLoop | ||
|
||
|
||
_database: Optional[Database] = None | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
logger = create_logger('api.lifespan') | ||
|
||
if _database is None: | ||
logger.warn('Manually creating database connection') | ||
database = Database(bot_config.db_file) | ||
else: | ||
logger.info('Connecting to database from FastAPI') | ||
database = _database | ||
|
||
app.state.database = database | ||
app.state.session_manager = SessionManager(database) | ||
yield | ||
|
||
|
||
app = FastAPI(lifespan=lifespan) | ||
app.include_router(account_router) | ||
app.include_router(oauth_router) | ||
|
||
|
||
@app.get('/') | ||
async def health_check(): | ||
return {'status': 'ok'} | ||
|
||
|
||
def _get_log_config() -> dict[str, Any]: | ||
log_config = LOGGING_CONFIG | ||
log_config['formatters']['default']['fmt'] = LOG_FMT_COLOR[INFO] | ||
log_config['formatters']['default']['datefmt'] = DATE_FMT_STR | ||
log_config['formatters']['access']['fmt'] = LOG_FMT_COLOR[INFO] | ||
|
||
return log_config | ||
|
||
|
||
def run_app(loop: 'AbstractEventLoop', db: Database): | ||
""" | ||
Run the API server in the bot's event loop. | ||
""" | ||
global _database # noqa: PLW0603 | ||
_database = db | ||
|
||
set_event_loop(loop) | ||
|
||
config = Config( | ||
app=app, | ||
loop=loop, # type: ignore | ||
host='0.0.0.0', | ||
port=bot_config.server_port, | ||
log_config=_get_log_config(), | ||
) | ||
server = Server(config) | ||
|
||
loop.create_task(server.serve()) | ||
|
||
|
||
if __name__ == '__main__': | ||
run( | ||
app='bot.api.main:app', | ||
host='127.0.0.1', | ||
port=bot_config.server_port, | ||
reload=True, | ||
reload_dirs=['bot/api'], | ||
log_config=_get_log_config(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class AccountResponse(BaseModel): | ||
username: str = Field(description="The user's username.") | ||
spotify_logged_in: bool = Field( | ||
description='Whether the user is logged in to Spotify.' | ||
) | ||
spotify_username: Optional[str] = Field( | ||
default=None, description="The user's Spotify username, if logged in." | ||
) | ||
lastfm_logged_in: bool = Field( | ||
description='Whether the user is logged in to Last.fm.' | ||
) | ||
lastfm_username: Optional[str] = Field( | ||
default=None, description="The user's Last.fm username, if logged in." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class OAuthResponse(BaseModel): | ||
session_id: str = Field(description='The session ID for the user.') | ||
jwt: str = Field(description='The JSON Web Token for the user.') | ||
|
||
|
||
class DiscordUser(BaseModel): | ||
id: int = Field(description='The user ID.') | ||
username: str = Field(description='The username.') | ||
discriminator: str = Field(description='The discriminator.') | ||
avatar: Optional[str] = Field( | ||
default=None, description='The avatar hash, if the user has one.' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class Session(BaseModel): | ||
user_id: int | ||
session_id: str | ||
expiration_time: int |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from fastapi import APIRouter | ||
|
||
from .login import get_login_url as route_login | ||
from .me import get_logged_in_user as route_me | ||
|
||
account_router = APIRouter(prefix='/account', tags=['account']) | ||
account_router.add_api_route('/login', route_login, methods=['GET']) | ||
account_router.add_api_route('/me', route_me, methods=['GET']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from secrets import token_urlsafe | ||
|
||
from fastapi import HTTPException | ||
from fastapi.responses import RedirectResponse | ||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR | ||
from yarl import URL | ||
|
||
from bot.utils.config import config as bot_config | ||
|
||
|
||
async def get_login_url() -> RedirectResponse: | ||
oauth_id = bot_config.discord_oauth_id | ||
base_url = bot_config.base_url | ||
|
||
if oauth_id is None or base_url is None: | ||
raise HTTPException( | ||
status_code=HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail='Missing Discord OAuth ID or base URL', | ||
) | ||
|
||
state = token_urlsafe(16) | ||
|
||
url = URL.build( | ||
scheme='https', | ||
host='discord.com', | ||
path='/api/oauth2/authorize', | ||
query={ | ||
'client_id': oauth_id, | ||
'response_type': 'code', | ||
'scope': 'identify guilds email', | ||
'redirect_uri': f'{base_url}/oauth/discord', | ||
'state': state, | ||
'prompt': 'none', | ||
}, | ||
) | ||
|
||
response = RedirectResponse(url=str(url)) | ||
response.set_cookie('state', state, httponly=True, samesite='lax') | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Route for getting the current user's account information. | ||
""" | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
|
||
from fastapi import Depends | ||
|
||
from bot.api.depends.database import database_dependency | ||
from bot.api.depends.session import session_dependency | ||
from bot.api.models.account import AccountResponse | ||
|
||
if TYPE_CHECKING: | ||
from bot.database import Database | ||
from bot.models.oauth import LastfmAuth, OAuth | ||
|
||
|
||
async def get_logged_in_user( | ||
user: 'OAuth' = Depends(session_dependency), | ||
db: 'Database' = Depends(database_dependency), | ||
) -> AccountResponse: | ||
spotify_username = None | ||
spotify: Optional['OAuth'] = db.get_oauth('spotify', user.user_id) | ||
if spotify is not None: | ||
spotify_username = spotify.username | ||
|
||
lastfm_username = None | ||
lastfm: Optional['LastfmAuth'] = db.get_lastfm_credentials(user.user_id) | ||
if lastfm is not None: | ||
lastfm_username = lastfm.username | ||
|
||
return AccountResponse( | ||
username=user.username, | ||
spotify_logged_in=spotify is not None, | ||
spotify_username=spotify_username, | ||
lastfm_logged_in=lastfm is not None, | ||
lastfm_username=lastfm_username, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from fastapi import APIRouter | ||
|
||
from .discord import discord_oauth as route_discord | ||
|
||
oauth_router = APIRouter(prefix='/oauth', tags=['oauth']) | ||
oauth_router.add_api_route('/discord', route_discord, methods=['GET']) |
Oops, something went wrong.