Skip to content

Commit

Permalink
bot: Start implementing new FastAPI-based server
Browse files Browse the repository at this point in the history
  • Loading branch information
jareddantis committed Mar 31, 2024
1 parent 9504243 commit f6d4cc6
Show file tree
Hide file tree
Showing 24 changed files with 1,132 additions and 53 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"--disable=too-many-return-statements",
"--disable=too-many-branches"
],
"editor.formatOnSave": true
"editor.formatOnSave": true,
"editor.defaultFormatter": "charliermarsh.ruff"
}
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ install:
dev-frontend: config.yml blanco.db
poetry run python -m bot.dev_server

dev-backend: config.yml blanco.db
poetry run python -m bot.api.main

dev: config.yml blanco.db
poetry run python -m bot.main

Expand Down
33 changes: 33 additions & 0 deletions bot/api/depends/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING

from fastapi import HTTPException, Request
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

if TYPE_CHECKING:
from bot.database import Database


def database_dependency(request: Request) -> 'Database':
"""
FastAPI dependency to get the database object.
Args:
request (web.Request): The request.
Returns:
Database: The database object.
"""

state = request.app.state
if not hasattr(state, 'database'):
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
)

database: 'Database' = state.database
if database is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
)

return database
58 changes: 58 additions & 0 deletions bot/api/depends/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING, Optional

from fastapi import Depends, HTTPException, Request
from starlette.status import HTTP_401_UNAUTHORIZED

from .database import database_dependency

if TYPE_CHECKING:
from bot.api.utils.session import SessionManager
from bot.database import Database
from bot.models.oauth import OAuth


EXPECTED_AUTH_SCHEME = 'Bearer'
EXPECTED_AUTH_PARTS = 2


def session_dependency(
request: Request, db: 'Database' = Depends(database_dependency)
) -> 'OAuth':
"""
FastAPI dependency to get the requesting user's info.
Args:
request (web.Request): The request.
Returns:
OAuth: The info for the current Discord user.
"""

authorization = request.headers.get('Authorization')
if authorization is None:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail='No authorization header'
)

parts = authorization.split()
if len(parts) != EXPECTED_AUTH_PARTS:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization header'
)

scheme, token = parts
if scheme != EXPECTED_AUTH_SCHEME:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization scheme'
)

session_manager: 'SessionManager' = request.app.state.session_manager
session = session_manager.decode_session(token)
if session is None:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='Invalid session')

user: Optional['OAuth'] = db.get_oauth('discord', session.user_id)
if user is None:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='User not found')

return user
17 changes: 17 additions & 0 deletions bot/api/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Nextcord extension that runs the API server for the bot
"""

from typing import TYPE_CHECKING

from .main import run_app

if TYPE_CHECKING:
from bot.utils.blanco import BlancoBot


def setup(bot: 'BlancoBot'):
"""
Run the API server within the bot's existing event loop.
"""
run_app(bot.loop, bot.database)
93 changes: 93 additions & 0 deletions bot/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Main module for the API server.
"""

from asyncio import set_event_loop
from contextlib import asynccontextmanager
from logging import INFO
from typing import TYPE_CHECKING, Any, Optional

from fastapi import FastAPI
from uvicorn import Config, Server, run
from uvicorn.config import LOGGING_CONFIG

from bot.database import Database
from bot.utils.config import config as bot_config
from bot.utils.logger import DATE_FMT_STR, LOG_FMT_COLOR, create_logger

from .routes.account import account_router
from .routes.oauth import oauth_router
from .utils.session import SessionManager

if TYPE_CHECKING:
from asyncio import AbstractEventLoop


_database: Optional[Database] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
logger = create_logger('api.lifespan')

if _database is None:
logger.warn('Manually creating database connection')
database = Database(bot_config.db_file)
else:
logger.info('Connecting to database from FastAPI')
database = _database

app.state.database = database
app.state.session_manager = SessionManager(database)
yield


app = FastAPI(lifespan=lifespan)
app.include_router(account_router)
app.include_router(oauth_router)


@app.get('/')
async def health_check():
return {'status': 'ok'}


def _get_log_config() -> dict[str, Any]:
log_config = LOGGING_CONFIG
log_config['formatters']['default']['fmt'] = LOG_FMT_COLOR[INFO]
log_config['formatters']['default']['datefmt'] = DATE_FMT_STR
log_config['formatters']['access']['fmt'] = LOG_FMT_COLOR[INFO]

return log_config


def run_app(loop: 'AbstractEventLoop', db: Database):
"""
Run the API server in the bot's event loop.
"""
global _database # noqa: PLW0603
_database = db

set_event_loop(loop)

config = Config(
app=app,
loop=loop, # type: ignore
host='0.0.0.0',
port=bot_config.server_port,
log_config=_get_log_config(),
)
server = Server(config)

loop.create_task(server.serve())


if __name__ == '__main__':
run(
app='bot.api.main:app',
host='127.0.0.1',
port=bot_config.server_port,
reload=True,
reload_dirs=['bot/api'],
log_config=_get_log_config(),
)
19 changes: 19 additions & 0 deletions bot/api/models/account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Optional

from pydantic import BaseModel, Field


class AccountResponse(BaseModel):
username: str = Field(description="The user's username.")
spotify_logged_in: bool = Field(
description='Whether the user is logged in to Spotify.'
)
spotify_username: Optional[str] = Field(
default=None, description="The user's Spotify username, if logged in."
)
lastfm_logged_in: bool = Field(
description='Whether the user is logged in to Last.fm.'
)
lastfm_username: Optional[str] = Field(
default=None, description="The user's Last.fm username, if logged in."
)
17 changes: 17 additions & 0 deletions bot/api/models/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Optional

from pydantic import BaseModel, Field


class OAuthResponse(BaseModel):
session_id: str = Field(description='The session ID for the user.')
jwt: str = Field(description='The JSON Web Token for the user.')


class DiscordUser(BaseModel):
id: int = Field(description='The user ID.')
username: str = Field(description='The username.')
discriminator: str = Field(description='The discriminator.')
avatar: Optional[str] = Field(
default=None, description='The avatar hash, if the user has one.'
)
7 changes: 7 additions & 0 deletions bot/api/models/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from pydantic import BaseModel


class Session(BaseModel):
user_id: int
session_id: str
expiration_time: int
8 changes: 8 additions & 0 deletions bot/api/routes/account/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from fastapi import APIRouter

from .login import get_login_url as route_login
from .me import get_logged_in_user as route_me

account_router = APIRouter(prefix='/account', tags=['account'])
account_router.add_api_route('/login', route_login, methods=['GET'])
account_router.add_api_route('/me', route_me, methods=['GET'])
39 changes: 39 additions & 0 deletions bot/api/routes/account/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from secrets import token_urlsafe

from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from yarl import URL

from bot.utils.config import config as bot_config


async def get_login_url() -> RedirectResponse:
oauth_id = bot_config.discord_oauth_id
base_url = bot_config.base_url

if oauth_id is None or base_url is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail='Missing Discord OAuth ID or base URL',
)

state = token_urlsafe(16)

url = URL.build(
scheme='https',
host='discord.com',
path='/api/oauth2/authorize',
query={
'client_id': oauth_id,
'response_type': 'code',
'scope': 'identify guilds email',
'redirect_uri': f'{base_url}/oauth/discord',
'state': state,
'prompt': 'none',
},
)

response = RedirectResponse(url=str(url))
response.set_cookie('state', state, httponly=True, samesite='lax')
return response
38 changes: 38 additions & 0 deletions bot/api/routes/account/me.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Route for getting the current user's account information.
"""

from typing import TYPE_CHECKING, Optional

from fastapi import Depends

from bot.api.depends.database import database_dependency
from bot.api.depends.session import session_dependency
from bot.api.models.account import AccountResponse

if TYPE_CHECKING:
from bot.database import Database
from bot.models.oauth import LastfmAuth, OAuth


async def get_logged_in_user(
user: 'OAuth' = Depends(session_dependency),
db: 'Database' = Depends(database_dependency),
) -> AccountResponse:
spotify_username = None
spotify: Optional['OAuth'] = db.get_oauth('spotify', user.user_id)
if spotify is not None:
spotify_username = spotify.username

lastfm_username = None
lastfm: Optional['LastfmAuth'] = db.get_lastfm_credentials(user.user_id)
if lastfm is not None:
lastfm_username = lastfm.username

return AccountResponse(
username=user.username,
spotify_logged_in=spotify is not None,
spotify_username=spotify_username,
lastfm_logged_in=lastfm is not None,
lastfm_username=lastfm_username,
)
6 changes: 6 additions & 0 deletions bot/api/routes/oauth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from fastapi import APIRouter

from .discord import discord_oauth as route_discord

oauth_router = APIRouter(prefix='/oauth', tags=['oauth'])
oauth_router.add_api_route('/discord', route_discord, methods=['GET'])
Loading

0 comments on commit f6d4cc6

Please sign in to comment.