Skip to content

Commit

Permalink
Dynamic mirror round robin auto-weighting for downloads (#3)
Browse files Browse the repository at this point in the history
* Mirror prioritization via request tracking in DB

* better tracking

* Cache mirror score in app mem for 5mins

* comment todo

* Rewrite prioritization

* Remove memes

* Better handling of 404 case

* rm print

* Cleanup osz2 served log

* Log mirror weight

* bit better log

* Add TODO for mirror_cache_age

* Remove unused TimedOut

* remove unnecessary pattern

* Reduce unnecessary code a bit

* Deploy to prod

* Update healthcheck settings to boot up faster

* Debug logs

* Remove 408 case

* Stop takeover for merge
  • Loading branch information
cmyui authored Jun 22, 2024
1 parent a1aa85c commit 1aef3bd
Show file tree
Hide file tree
Showing 15 changed files with 284 additions and 87 deletions.
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ CODE_HOTRELOAD=

OSU_API_V2_CLIENT_ID=
OSU_API_V2_CLIENT_SECRET=

DB_USER=cmyui
DB_PASS=lol123
DB_HOST=localhost
DB_PORT=3306
DB_NAME=akatsuki
1 change: 1 addition & 0 deletions app/adapters/beatmap_mirrors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class BeatmapMirror(ABC):

def __init__(self, *args: Any, **kwargs: Any) -> None:
self.http_client = httpx.AsyncClient()
self.weight = 0
super().__init__(*args, **kwargs)

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions app/adapters/beatmap_mirrors/mino.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

import httpx

from app.adapters.beatmap_mirrors import BeatmapMirror


Expand All @@ -12,6 +14,7 @@ async def fetch_beatmap_zip_data(self, beatmapset_id: int) -> bytes | None:
logging.info(f"Fetching beatmapset osz2 from mino: {beatmapset_id}")
response = await self.http_client.get(
f"{self.base_url}/d/{beatmapset_id}",
timeout=httpx.Timeout(None, connect=2),
)
response.raise_for_status()
return response.read()
Expand Down
163 changes: 81 additions & 82 deletions app/adapters/beatmap_mirrors/mirror_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,103 @@
import asyncio
import logging
import random
import time
from datetime import datetime

from app.adapters.beatmap_mirrors import BeatmapMirror
from app.adapters.beatmap_mirrors.gatari import GatariMirror
from app.adapters.beatmap_mirrors.mino import MinoMirror
from app.adapters.beatmap_mirrors.nerinyan import NerinyanMirror
from app.adapters.beatmap_mirrors.osu_direct import OsuDirectMirror
from app.adapters.beatmap_mirrors.ripple import RippleMirror
from app.repositories import beatmap_mirror_requests
from app.scheduling import DynamicWeightedRoundRobin

BEATMAP_MIRRORS: list[BeatmapMirror] = [
# GatariMirror(),
# MinoMirror(),
NerinyanMirror(),
OsuDirectMirror(),
# Disabled as ripple only supports ranked maps
# RippleMirror(),
]
# from app.adapters.beatmap_mirrors.gatari import GatariMirror
# from app.adapters.beatmap_mirrors.mino import MinoMirror
# from app.adapters.beatmap_mirrors.ripple import RippleMirror

ZIP_FILE_HEADER = b"PK\x03\x04"

async def run_with_semaphore(
semaphore: asyncio.Semaphore,
mirror: BeatmapMirror,
beatmapset_id: int,
) -> tuple[BeatmapMirror, bytes | None]:
async with semaphore:
return (mirror, await mirror.fetch_beatmap_zip_data(beatmapset_id))
BEATMAP_SELECTOR = DynamicWeightedRoundRobin(
mirrors=[
# GatariMirror(),
# MinoMirror(),
NerinyanMirror(),
OsuDirectMirror(),
# Disabled as ripple only supports ranked maps
# RippleMirror(),
],
)


class TimedOut: ...


TIMED_OUT = TimedOut()


async def fetch_beatmap_zip_data(beatmapset_id: int) -> bytes | TimedOut | None:
async def fetch_beatmap_zip_data(beatmapset_id: int) -> bytes | None:
"""\
Parallelize calls with a timeout across up to 5 mirrors at time,
to ensure our clients get a response in a reasonable time.
Fetch a beatmapset .osz2 file by any means necessary, balancing upon
multiple underlying beatmap mirrors to ensure the best possible
availability and performance.
"""
started_at = datetime.now()

await BEATMAP_SELECTOR.update_all_mirror_and_selector_weights()

while True:
mirror = BEATMAP_SELECTOR.select_mirror()
beatmap_zip_data: bytes | None = None
try:
beatmap_zip_data = await mirror.fetch_beatmap_zip_data(beatmapset_id)

if beatmap_zip_data is not None and (
not beatmap_zip_data.startswith(ZIP_FILE_HEADER)
or len(beatmap_zip_data) < 20_000
):
raise ValueError("Received bad osz2 data from mirror")
except Exception as exc:
ended_at = datetime.now()
await beatmap_mirror_requests.create(
request_url=f"{mirror.base_url}/d/{beatmapset_id}",
api_key_id=None,
mirror_name=mirror.name,
success=False,
started_at=started_at,
ended_at=ended_at,
response_size=len(beatmap_zip_data) if beatmap_zip_data else 0,
response_error=str(exc),
)
await BEATMAP_SELECTOR.update_all_mirror_and_selector_weights()
logging.warning(
"Failed to fetch beatmapset osz2 from mirror",
exc_info=True,
extra={
"mirror_name": mirror.name,
"mirror_weight": mirror.weight,
"beatmapset_id": beatmapset_id,
},
)
continue
else:
break

ended_at = datetime.now()

await beatmap_mirror_requests.create(
request_url=f"{mirror.base_url}/d/{beatmapset_id}",
api_key_id=None,
mirror_name=mirror.name,
success=True,
started_at=started_at,
ended_at=ended_at,
response_size=len(beatmap_zip_data) if beatmap_zip_data else 0,
response_error=None,
)
await BEATMAP_SELECTOR.update_all_mirror_and_selector_weights()

# TODO: it would be nice to be able to stream the responses,
# but that would require a different approach where the
# discovery process would be complete once the mirror has
# started streaming, rather than after the response has
# been read in full.

concurrency_limit = 5
global_timeout = 15
semaphore = asyncio.Semaphore(concurrency_limit)

start_time = time.time()

# TODO: prioritization based on reliability, speed, etc.
random.shuffle(BEATMAP_MIRRORS)

coroutines = [
asyncio.create_task(
run_with_semaphore(
semaphore,
mirror,
beatmapset_id,
),
)
for mirror in BEATMAP_MIRRORS
]
try:
done, pending = await asyncio.wait(
coroutines,
return_when=asyncio.FIRST_COMPLETED,
timeout=global_timeout,
)
for task in pending:
task.cancel()
first_result = await list(done)[0]
except TimeoutError:
return None

# TODO: log which mirrors finished, and which timed out

mirror, result = first_result
if result is None:
return None

end_time = time.time()
ms_elapsed = (end_time - start_time) * 1000
ms_elapsed = (ended_at.timestamp() - started_at.timestamp()) * 1000

logging.info(
"A mirror was first to finish during .osz2 aggregate request",
"Served beatmapset osz2 from mirror",
extra={
"mirror_name": mirror.name,
"mirror_weight": mirror.weight,
"beatmapset_id": beatmapset_id,
"ms_elapsed": ms_elapsed,
"data_size": len(result),
"bad_data": (
result
if not result.startswith(b"PK\x03\x04") or len(result) < 20_000
else None
"data_size": (
len(beatmap_zip_data) if beatmap_zip_data is not None else None
),
},
)
return result
return beatmap_zip_data
3 changes: 3 additions & 0 deletions app/adapters/beatmap_mirrors/osu_direct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

import httpx

from app.adapters.beatmap_mirrors import BeatmapMirror


Expand All @@ -12,6 +14,7 @@ async def fetch_beatmap_zip_data(self, beatmapset_id: int) -> bytes | None:
logging.info(f"Fetching beatmapset osz2 from osu!direct: {beatmapset_id}")
response = await self.http_client.get(
f"{self.base_url}/d/{beatmapset_id}",
timeout=httpx.Timeout(None, connect=2),
)
response.raise_for_status()
return response.read()
Expand Down
14 changes: 14 additions & 0 deletions app/adapters/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import urllib.parse


def create_dsn(
driver: str | None,
username: str,
password: str,
host: str,
port: int | None,
database: str,
) -> str:
driver_str = f"+{driver}" if driver else ""
passwd_str = urllib.parse.quote_plus(password) if password else ""
return f"mysql{driver_str}://{username}:{passwd_str}@{host}:{port}/{database}"
2 changes: 0 additions & 2 deletions app/api/v1/osz2_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,5 @@ async def download_beatmapset_osz2(beatmapset_id: int) -> Response:
"Content-Disposition": f"attachment; filename={beatmapset_id}.osz",
},
)
if isinstance(beatmap_zip_data, mirror_aggregate.TimedOut):
return Response(status_code=408)

return Response(status_code=404)
28 changes: 28 additions & 0 deletions app/init_api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from databases import Database
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from starlette.middleware.base import RequestResponseEndpoint

from app import settings
from app import state
from app.adapters import mysql
from app.api import api_router


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan hook: connect the database on startup,
    disconnect it on shutdown.

    The app serves requests while suspended at `yield`; control
    returns here when the server is shutting down.

    NOTE(review): assumes `state.database` has already been assigned
    (by init_db) before the application starts — confirm init_db runs
    during init_api, before the server boots.
    """
    await state.database.connect()
    yield
    await state.database.disconnect()


def init_routes(app: FastAPI) -> FastAPI:
    """Mount the aggregate API router on the app and return the app."""
    app.include_router(api_router)
    return app
Expand All @@ -29,15 +41,31 @@ async def http_middleware(
return app


def init_db(app: FastAPI) -> FastAPI:
    """Construct the global Database handle from environment settings.

    Stores the handle on module-level app state (`state.database`) so
    the lifespan hook can connect/disconnect it; returns the app to
    allow chaining with the other init_* steps.
    """
    dsn = mysql.create_dsn(
        driver="aiomysql",
        username=settings.DB_USER,
        password=settings.DB_PASS,
        host=settings.DB_HOST,
        port=settings.DB_PORT,
        database=settings.DB_NAME,
    )
    state.database = Database(url=dsn)
    return app


def init_api() -> FastAPI:
    """Create and fully configure the FastAPI application.

    API docs endpoints (openapi/docs/redoc) are disabled in
    production; the lifespan hook manages the database connection.
    """
    hide_docs = settings.APP_ENV == "production"
    app = FastAPI(
        openapi_url=None if hide_docs else "/openapi.json",
        docs_url=None if hide_docs else "/docs",
        redoc_url=None if hide_docs else "/redoc",
        swagger_ui_oauth2_redirect_url=None,
        lifespan=lifespan,
    )
    # Apply each configuration step in order: routes, middleware, db.
    for configure in (init_routes, init_middleware, init_db):
        app = configure(app)
    return app


Expand Down
Empty file added app/repositories/__init__.py
Empty file.
Loading

0 comments on commit 1aef3bd

Please sign in to comment.