From d7db3ef200d1747bec0a4802566fceabf31d290c Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:33:15 +0100 Subject: [PATCH 01/48] Change gatari.pw to configured download endpoint (#602) --- app/commands.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/app/commands.py b/app/commands.py index 279de7d63..81c5c0662 100644 --- a/app/commands.py +++ b/app/commands.py @@ -308,9 +308,7 @@ async def maplink(ctx: Context) -> str | None: if bmap is None: return "No map found!" - # gatari.pw & nerina.pw are pretty much the only - # reliable mirrors I know of? perhaps beatconnect - return f"[https://osu.gatari.pw/d/{bmap.set_id} {bmap.full_name}]" + return f"[{app.settings.MIRROR_DOWNLOAD_ENDPOINT}/{bmap.set_id} {bmap.full_name}]" @command(Privileges.UNRESTRICTED, aliases=["last", "r"]) From 34746869597de1b06be34f3c8d3711476982719a Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Fri, 16 Feb 2024 06:56:04 +0100 Subject: [PATCH 02/48] Add tail to make logs (#605) * Add tail to make logs * Add parameter to make logs --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 53892232b..e2fe84723 100644 --- a/Makefile +++ b/Makefile @@ -11,8 +11,9 @@ run-bg: run-caddy: caddy run --envfile .env --config ext/Caddyfile +last?=1 logs: - docker-compose logs -f bancho mysql redis + docker-compose logs -f bancho mysql redis --tail ${last} shell: poetry shell From 6e6d067741dfae91eef1a526629d4ac8fff12419 Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Fri, 16 Feb 2024 07:13:44 +0100 Subject: [PATCH 03/48] feat.: add setting to disable in-game registration (#603) * Add setting to disable in-game registration * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix spelling error * Simplify error code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Default ingame-registration to disabled * Moved error to constant * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fail fast --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: cmyui --- .env.example | 1 + app/api/domains/osu.py | 17 +++++++++++++++++ app/settings.py | 1 + docker-compose.test.yml | 1 + docker-compose.yml | 1 + 5 files changed, 21 insertions(+) diff --git a/.env.example b/.env.example index 8da06f014..4e43ce48e 100644 --- a/.env.example +++ b/.env.example @@ -54,6 +54,7 @@ PP_CACHED_ACCS=90,95,98,99,100 DISALLOWED_NAMES=mrekk,vaxei,btmc,cookiezi DISALLOWED_PASSWORDS=password,abc123 DISALLOW_OLD_CLIENTS=True +DISALLOW_INGAME_REGISTRATION=True DISCORD_AUDIT_LOG_WEBHOOK= diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index a3e55f6dd..3cbb408b3 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -1644,6 +1644,16 @@ async def peppyDMHandler() -> Response: """ ingame registration """ +INGAME_REGISTRATION_DISALLOWED_ERROR = { + "form_error": { + "user": { + "password": [ + "In-game registration is disabled. 
Please register on the website.", + ], + }, + }, +} + @router.post("/users") async def register_account( @@ -1663,6 +1673,13 @@ async def register_account( status_code=status.HTTP_400_BAD_REQUEST, ) + # Disable in-game registration if enabled + if app.settings.DISALLOW_INGAME_REGISTRATION: + return ORJSONResponse( + content=INGAME_REGISTRATION_DISALLOWED_ERROR, + status_code=status.HTTP_400_BAD_REQUEST, + ) + # ensure all args passed # are safe for registration. errors: Mapping[str, list[str]] = defaultdict(list) diff --git a/app/settings.py b/app/settings.py index f925e7566..457cd74d7 100644 --- a/app/settings.py +++ b/app/settings.py @@ -67,6 +67,7 @@ DISALLOWED_NAMES = read_list(os.environ["DISALLOWED_NAMES"]) DISALLOWED_PASSWORDS = read_list(os.environ["DISALLOWED_PASSWORDS"]) DISALLOW_OLD_CLIENTS = read_bool(os.environ["DISALLOW_OLD_CLIENTS"]) +DISALLOW_INGAME_REGISTRATION = read_bool(os.environ["DISALLOW_INGAME_REGISTRATION"]) DISCORD_AUDIT_LOG_WEBHOOK = os.environ["DISCORD_AUDIT_LOG_WEBHOOK"] diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 7892301c9..8c0a92d9c 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -84,6 +84,7 @@ services: - DISALLOWED_NAMES=${DISALLOWED_NAMES} - DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS} - DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS} + - DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION} - DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK} - AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS} - SSL_CERT_PATH=${SSL_CERT_PATH} diff --git a/docker-compose.yml b/docker-compose.yml index 6484f0722..7ac66d7a7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -80,6 +80,7 @@ services: - DISALLOWED_NAMES=${DISALLOWED_NAMES} - DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS} - DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS} + - DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION} - DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK} - AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS} - SSL_CERT_PATH=${SSL_CERT_PATH} From 6c093b1af018e786a30596f5e45a15aa1a2b37be Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Fri, 16 Feb 2024 07:16:25 +0100 Subject: [PATCH 04/48] Add rounding to score on getscores (#607) --- app/api/domains/osu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index 3cbb408b3..ac35a609a 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -1393,7 +1393,7 @@ async def getScores( **personal_best_score_row, name=player.full_name, userid=player.id, - score=int(personal_best_score_row["_score"]), + score=int(round(personal_best_score_row["_score"])), has_replay="1", ), ) @@ -1404,7 +1404,7 @@ async def getScores( [ SCORE_LISTING_FMTSTR.format( **s, - score=int(s["_score"]), + score=int(round(s["_score"])), has_replay="1", rank=idx + 1, ) From 85ed13905176073d328f30e2b4f10b40e50846d8 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 01:22:51 -0500 Subject: [PATCH 05/48] Mock achievement downloads in tests (2x speedup to suite) (#596) * mock achievement downloads in tests * raw str * mock out default ava dl; try to refactor --- tests/conftest.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 475dc8035..f07aebc51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,10 @@ import httpx import pytest +import respx from asgi_lifespan import 
LifespanManager from asgi_lifespan._types import ASGIApp +from fastapi import status from app.api.init_api import asgi_app @@ -16,6 +18,26 @@ # (We do not need an asynchronous http client for our tests) +@pytest.fixture(autouse=True) +def mock_out_initial_image_downloads(respx_mock: respx.MockRouter) -> None: + # mock out default avatar download + respx_mock.get(url__regex=r"https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png").mock( + return_value=httpx.Response( + status_code=status.HTTP_200_OK, + headers={"Content-Type": "image/png"}, + content=b"i am a png file", + ), + ) + # mock out achievement image downloads + respx_mock.get(url__regex=r"https://assets.ppy.sh/medals/client/.+").mock( + return_value=httpx.Response( + status_code=status.HTTP_200_OK, + headers={"Content-Type": "image/png"}, + content=b"i am a png file", + ), + ) + + @pytest.fixture async def app() -> AsyncIterator[ASGIApp]: async with LifespanManager( From 55c5d23d5119d3e5c58b019d496b2b4419640dc8 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 01:25:32 -0500 Subject: [PATCH 06/48] Prefer less permissive mock url pattern (#608) --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index f07aebc51..371ee10c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ @pytest.fixture(autouse=True) def mock_out_initial_image_downloads(respx_mock: respx.MockRouter) -> None: # mock out default avatar download - respx_mock.get(url__regex=r"https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png").mock( + respx_mock.get("https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png").mock( return_value=httpx.Response( status_code=status.HTTP_200_OK, headers={"Content-Type": "image/png"}, From 59b17c8293eb74b8ecd84c1f5d413a4148a1e714 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 02:50:48 -0500 Subject: [PATCH 07/48] Simplify the `Player` object/class (#610) * avoid usage of kwargs * fix type errs * simplify * more clear noop packet handler usecase * f * more specific type ignore * bugfix Player.geoloc to avoid shared refs * ensure priv input to player is of Privileges type * remove most logic from `Player.__init__` * simpler import --- app/api/domains/cho.py | 3 +- app/objects/collections.py | 7 ++- app/objects/player.py | 97 ++++++++++++++++++-------------------- 3 files changed, 54 insertions(+), 53 deletions(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 8f7d856c1..623a65210 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -836,8 +836,9 @@ async def handle_osu_login_request( player = Player( id=user_info["id"], name=user_info["name"], - priv=user_info["priv"], + priv=Privileges(user_info["priv"]), pw_bcrypt=user_info["pw_bcrypt"].encode(), + token=Player.generate_token(), clan=clan, clan_priv=clan_priv, geoloc=geoloc, diff --git a/app/objects/collections.py b/app/objects/collections.py index 54e22dae3..450d0de35 100644 --- a/app/objects/collections.py +++ b/app/objects/collections.py @@ -227,8 +227,9 @@ async def get_sql( return Player( id=player["id"], name=player["name"], - pw_bcrypt=player["pw_bcrypt"].encode(), priv=Privileges(player["priv"]), + pw_bcrypt=player["pw_bcrypt"].encode(), + token=Player.generate_token(), clan=clan, clan_priv=clan_priv, geoloc={ @@ -466,8 +467,10 @@ async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: app.state.sessions.bot = Player( id=1, name=bot["name"], - login_time=float(0x7FFFFFFF), # (never auto-dc) priv=Privileges.UNRESTRICTED, + 
pw_bcrypt=None, + token=Player.generate_token(), + login_time=float(0x7FFFFFFF), # (never auto-dc) is_bot_client=True, ) app.state.sessions.players.append(app.state.sessions.bot) diff --git a/app/objects/player.py b/app/objects/player.py index 292fe82d7..9c86e5a1f 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -37,13 +37,13 @@ from app.objects.score import Score from app.repositories import logs as logs_repo from app.repositories import stats as stats_repo +from app.state.services import Geolocation from app.utils import escape_enum from app.utils import make_safe_name from app.utils import pymysql_encode if TYPE_CHECKING: from app.constants.privileges import ClanPrivileges - from app.objects.achievement import Achievement from app.objects.beatmap import Beatmap from app.objects.clan import Clan from app.objects.score import Score @@ -213,27 +213,59 @@ def __init__( self, id: int, name: str, - priv: int | Privileges, - **extras: Any, + priv: Privileges, + pw_bcrypt: bytes | None, + token: str, + clan: Clan | None = None, + clan_priv: ClanPrivileges | None = None, + geoloc: Geolocation | None = None, + utc_offset: int = 0, + pm_private: bool = False, + silence_end: int = 0, + donor_end: int = 0, + client_details: ClientDetails | None = None, + login_time: float = 0.0, + is_bot_client: bool = False, + is_tourney_client: bool = False, + api_key: str | None = None, ) -> None: + if geoloc is None: + geoloc = { + "latitude": 0.0, + "longitude": 0.0, + "country": {"acronym": "xx", "numeric": 0}, + } + self.id = id self.name = name self.safe_name = self.make_safe(self.name) + self.priv = priv + self.pw_bcrypt = pw_bcrypt + self.token = token + self.clan = clan + self.clan_priv = clan_priv + self.geoloc = geoloc + self.utc_offset = utc_offset + self.pm_private = pm_private + self.silence_end = silence_end + self.donor_end = donor_end + self.client_details = client_details + self.login_time = login_time + self.last_recv_time = login_time + self.is_bot_client = is_bot_client + self.is_tourney_client = is_tourney_client + self.api_key = api_key - if "pw_bcrypt" in extras: - self.pw_bcrypt: bytes | None = extras["pw_bcrypt"] - else: - self.pw_bcrypt = None + # avoid enqueuing packets to bot accounts. 
+ if self.is_bot_client: - # generate a token if not given - token = extras.get("token", None) - if token is not None and isinstance(token, str): - self.token = token - else: - self.token = self.generate_token() + def _noop_enqueue(data: bytes) -> None: + pass - # ensure priv is of type Privileges - self.priv = priv if isinstance(priv, Privileges) else Privileges(priv) + self.enqueue = _noop_enqueue # type: ignore[method-assign] + + self.away_msg: str | None = None + self.in_lobby = False self.stats: dict[GameMode, ModeData] = {} self.status = Status() @@ -248,32 +280,8 @@ def __init__( self.match: Match | None = None self.stealth = False - self.clan: Clan | None = extras.get("clan") - self.clan_priv: ClanPrivileges | None = extras.get("clan_priv") - - self.geoloc: app.state.services.Geolocation = extras.get( - "geoloc", - { - "latitude": 0.0, - "longitude": 0.0, - "country": {"acronym": "xx", "numeric": 0}, - }, - ) - - self.utc_offset = extras.get("utc_offset", 0) - self.pm_private = extras.get("pm_private", False) - self.away_msg: str | None = None - self.silence_end = extras.get("silence_end", 0) - self.donor_end = extras.get("donor_end", 0) - self.in_lobby = False - - self.client_details: ClientDetails | None = extras.get("client_details") self.pres_filter = PresenceFilter.Nil - login_time = extras.get("login_time", 0.0) - self.login_time = login_time - self.last_recv_time = login_time - # XXX: below is mostly implementation-specific & internal stuff # store most recent score for each gamemode. @@ -284,17 +292,6 @@ def __init__( # store the last beatmap /np'ed by the user. self.last_np: LastNp | None = None - # subject to possible change in the future, - # although if anything, bot accounts will - # probably just use the /api/ routes? - self.is_bot_client = extras.get("is_bot_client", False) - if self.is_bot_client: - self.enqueue = lambda data: None # type: ignore - - self.is_tourney_client = extras.get("is_tourney_client", False) - - self.api_key = extras.get("api_key", None) - self._packet_queue = bytearray() def __repr__(self) -> str: From eead47919aea6ce140e9f4b081f2d4f2b8345f58 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 11:03:39 -0500 Subject: [PATCH 08/48] Setup test coverage (text & html reports available in CI) (#609) * add coverage dep * add .coverage to .gitignore * coverage run & report in exec-tests script * html cov report * report --show-missing --fail-under=45 * upload code cov artifact in ci --- .github/workflows/test.yaml | 6 ++++ .gitignore | 1 + poetry.lock | 66 ++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + scripts/run-tests.sh | 4 ++- 5 files changed, 76 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 306571095..b05341a79 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -65,3 +65,9 @@ jobs: - name: Stop containers if: always() run: docker-compose down + + - name: Archive code coverage results + uses: actions/upload-artifact@v2 + with: + name: code-coverage-report + path: coverage/ diff --git a/.gitignore b/.gitignore index e2ec87460..33b467341 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ tools/cf_records.txt /.db-data/ /.redis-data/ poetry.toml +.coverage diff --git a/poetry.lock b/poetry.lock index 18460a2b9..c39627f27 100644 --- a/poetry.lock +++ b/poetry.lock @@ -427,6 +427,70 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = 
"coverage" +version = "7.4.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:077d366e724f24fc02dbfe9d946534357fda71af9764ff99d73c3c596001bbd7"}, + {file = "coverage-7.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0193657651f5399d433c92f8ae264aff31fc1d066deee4b831549526433f3f61"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d17bbc946f52ca67adf72a5ee783cd7cd3477f8f8796f59b4974a9b59cacc9ee"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3277f5fa7483c927fe3a7b017b39351610265308f5267ac6d4c2b64cc1d8d25"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dceb61d40cbfcf45f51e59933c784a50846dc03211054bd76b421a713dcdf19"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6008adeca04a445ea6ef31b2cbaf1d01d02986047606f7da266629afee982630"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c61f66d93d712f6e03369b6a7769233bfda880b12f417eefdd4f16d1deb2fc4c"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b9bb62fac84d5f2ff523304e59e5c439955fb3b7f44e3d7b2085184db74d733b"}, + {file = "coverage-7.4.1-cp310-cp310-win32.whl", hash = "sha256:f86f368e1c7ce897bf2457b9eb61169a44e2ef797099fb5728482b8d69f3f016"}, + {file = "coverage-7.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:869b5046d41abfea3e381dd143407b0d29b8282a904a19cb908fa24d090cc018"}, + {file = "coverage-7.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b8ffb498a83d7e0305968289441914154fb0ef5d8b3157df02a90c6695978295"}, + {file = "coverage-7.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3cacfaefe6089d477264001f90f55b7881ba615953414999c46cc9713ff93c8c"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d6850e6e36e332d5511a48a251790ddc545e16e8beaf046c03985c69ccb2676"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e961aa13b6d47f758cc5879383d27b5b3f3dcd9ce8cdbfdc2571fe86feb4dd"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dfd1e1b9f0898817babf840b77ce9fe655ecbe8b1b327983df485b30df8cc011"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6b00e21f86598b6330f0019b40fb397e705135040dbedc2ca9a93c7441178e74"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:536d609c6963c50055bab766d9951b6c394759190d03311f3e9fcf194ca909e1"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7ac8f8eb153724f84885a1374999b7e45734bf93a87d8df1e7ce2146860edef6"}, + {file = "coverage-7.4.1-cp311-cp311-win32.whl", hash = "sha256:f3771b23bb3675a06f5d885c3630b1d01ea6cac9e84a01aaf5508706dba546c5"}, + {file = "coverage-7.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:9d2f9d4cc2a53b38cabc2d6d80f7f9b7e3da26b2f53d48f05876fef7956b6968"}, + {file = "coverage-7.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f68ef3660677e6624c8cace943e4765545f8191313a07288a53d3da188bd8581"}, + {file = "coverage-7.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:23b27b8a698e749b61809fb637eb98ebf0e505710ec46a8aa6f1be7dc0dc43a6"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3424c554391dc9ef4a92ad28665756566a28fecf47308f91841f6c49288e66"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e0860a348bf7004c812c8368d1fc7f77fe8e4c095d661a579196a9533778e156"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe558371c1bdf3b8fa03e097c523fb9645b8730399c14fe7721ee9c9e2a545d3"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3468cc8720402af37b6c6e7e2a9cdb9f6c16c728638a2ebc768ba1ef6f26c3a1"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:02f2edb575d62172aa28fe00efe821ae31f25dc3d589055b3fb64d51e52e4ab1"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ca6e61dc52f601d1d224526360cdeab0d0712ec104a2ce6cc5ccef6ed9a233bc"}, + {file = "coverage-7.4.1-cp312-cp312-win32.whl", hash = "sha256:ca7b26a5e456a843b9b6683eada193fc1f65c761b3a473941efe5a291f604c74"}, + {file = "coverage-7.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:85ccc5fa54c2ed64bd91ed3b4a627b9cce04646a659512a051fa82a92c04a448"}, + {file = "coverage-7.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8bdb0285a0202888d19ec6b6d23d5990410decb932b709f2b0dfe216d031d218"}, + {file = "coverage-7.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:918440dea04521f499721c039863ef95433314b1db00ff826a02580c1f503e45"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379d4c7abad5afbe9d88cc31ea8ca262296480a86af945b08214eb1a556a3e4d"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b094116f0b6155e36a304ff912f89bbb5067157aff5f94060ff20bbabdc8da06"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2f5968608b1fe2a1d00d01ad1017ee27efd99b3437e08b83ded9b7af3f6f766"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:10e88e7f41e6197ea0429ae18f21ff521d4f4490aa33048f6c6f94c6045a6a75"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a4a3907011d39dbc3e37bdc5df0a8c93853c369039b59efa33a7b6669de04c60"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d224f0c4c9c98290a6990259073f496fcec1b5cc613eecbd22786d398ded3ad"}, + {file = "coverage-7.4.1-cp38-cp38-win32.whl", hash = "sha256:23f5881362dcb0e1a92b84b3c2809bdc90db892332daab81ad8f642d8ed55042"}, + {file = "coverage-7.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:a07f61fc452c43cd5328b392e52555f7d1952400a1ad09086c4a8addccbd138d"}, + {file = "coverage-7.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8e738a492b6221f8dcf281b67129510835461132b03024830ac0e554311a5c54"}, + {file = "coverage-7.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:46342fed0fff72efcda77040b14728049200cbba1279e0bf1188f1f2078c1d70"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9641e21670c68c7e57d2053ddf6c443e4f0a6e18e547e86af3fad0795414a628"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:aeb2c2688ed93b027eb0d26aa188ada34acb22dceea256d76390eea135083950"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1"}, + {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0491275c3b9971cdbd28a4595c2cb5838f08036bca31765bad5e17edf900b2c7"}, + {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8dfc5e195bbef80aabd81596ef52a1277ee7143fe419efc3c4d8ba2754671756"}, + {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1a78b656a4d12b0490ca72651fe4d9f5e07e3c6461063a9b6265ee45eb2bdd35"}, + {file = "coverage-7.4.1-cp39-cp39-win32.whl", hash = "sha256:f90515974b39f4dea2f27c0959688621b46d96d5a626cf9c53dbc653a895c05c"}, + {file = "coverage-7.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:64e723ca82a84053dd7bfcc986bdb34af8d9da83c521c19d6b472bc6880e191a"}, + {file = "coverage-7.4.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:32a8d985462e37cfdab611a6f95b09d7c091d07668fdc26e47a725ee575fe166"}, + {file = "coverage-7.4.1.tar.gz", hash = "sha256:1ed4b95480952b1a26d863e546fa5094564aa0065e1e5f0d4d0041f293251d04"}, +] + +[package.extras] +toml = ["tomli"] + [[package]] name = "cryptography" version = "42.0.2" @@ -1835,4 +1899,4 @@ cython = "*" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "7c3881eb25fe26fec7b3c82c27585bd79d29fdb50b83137837d42a500bd167cd" +content-hash = "64468dbf2aab41fb6a602d155a304aa26d7562256dde56ead7973e8991192b9a" diff --git a/pyproject.toml b/pyproject.toml index 47a9943f2..9ab2cdcd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ pytest-asyncio = "0.23.5" asgi-lifespan = "2.1.0" respx = "0.20.2" tzdata = "2024.1" +coverage = "^7.4.1" [tool.poetry.group.dev.dependencies] pre-commit = "3.6.1" diff --git a/scripts/run-tests.sh b/scripts/run-tests.sh index 8009f75e8..849760c4d 100755 --- a/scripts/run-tests.sh +++ b/scripts/run-tests.sh @@ -56,4 +56,6 @@ execDBStatement "source /srv/root/migrations/base.sql" # Run tests echo "Running tests..." -pytest -vv -s tests/ +coverage run -m pytest -vv -s tests/ +coverage report --show-missing --fail-under=45 +coverage html From 4310cf0769f903c3c9d81dc9cb109986530405a1 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 13:42:48 -0500 Subject: [PATCH 09/48] Hard disallow deprecated config attrs after supported-until date --- app/settings_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/app/settings_utils.py b/app/settings_utils.py index 63db43e19..e71249c42 100644 --- a/app/settings_utils.py +++ b/app/settings_utils.py @@ -28,6 +28,11 @@ def support_deprecated_vars( val2 = os.getenv(deprecated_name) if val2: + if until < date.today(): + raise ValueError( + f'The "{deprecated_name}" config option has been deprecated as of {until.isoformat()} and is no longer supported. Use {new_name} instead.', + ) + log( f'The "{deprecated_name}" config option has been deprecated and will be supported until {until.isoformat()}. 
Use {new_name} instead.', Ansi.LYELLOW, From 9e0d4811e8092c9266769ab2bd3da835116e7613 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 15:30:58 -0500 Subject: [PATCH 10/48] Split `MapPools` and related collections/cache into repositories (#597) * pool maps repo and finish tourney pools * fixes for `match.tourney_pool.maps` usage * import fix * type fix * style * fix mods display & sort in !pool info * fix incorrect column name in pools fetch_many * fix incorrect null handling in pools repo fetch_by_name * fix no-arg case handling in !mp scrim --- app/api/v1/api.py | 30 +++- app/commands.py | 203 +++++++++++++------------- app/objects/collections.py | 92 ------------ app/objects/match.py | 61 +------- app/repositories/tourney_pool_maps.py | 180 +++++++++++++++++++++++ app/repositories/tourney_pools.py | 148 +++++++++++++++++++ app/state/sessions.py | 3 - 7 files changed, 453 insertions(+), 264 deletions(-) create mode 100644 app/repositories/tourney_pool_maps.py create mode 100644 app/repositories/tourney_pools.py diff --git a/app/api/v1/api.py b/app/api/v1/api.py index 3aa34aa15..20469a57c 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -29,6 +29,8 @@ from app.repositories import players as players_repo from app.repositories import scores as scores_repo from app.repositories import stats as stats_repo +from app.repositories import tourney_pool_maps as tourney_pool_maps_repo +from app.repositories import tourney_pools as tourney_pools_repo from app.usecases.performance import ScoreParams AVATARS_PATH = SystemPath.cwd() / ".data/avatars" @@ -994,22 +996,36 @@ async def api_get_pool( ) -> Response: """Return information of a given mappool.""" - pool = app.state.sessions.pools.get(id=pool_id) - if not pool: + tourney_pool = await tourney_pools_repo.fetch_by_id(id=pool_id) + if tourney_pool is None: return ORJSONResponse( {"status": "Pool not found."}, status_code=status.HTTP_404_NOT_FOUND, ) + tourney_pool_maps: dict[tuple[int, int], Beatmap] = {} + for pool_map in await tourney_pool_maps_repo.fetch_many(pool_id=pool_id): + bmap = await Beatmap.from_bid(pool_map["map_id"]) + if bmap is not None: + tourney_pool_maps[(pool_map["mods"], pool_map["slot"])] = bmap + + pool_creator = app.state.sessions.players.get(id=tourney_pool["created_by"]) + + if pool_creator is None: + return ORJSONResponse( + {"status": "Pool creator not found."}, + status_code=status.HTTP_404_NOT_FOUND, + ) + return ORJSONResponse( { - "id": pool.id, - "name": pool.name, - "created_at": pool.created_at, - "created_by": format_player_basic(pool.created_by), + "id": tourney_pool["id"], + "name": tourney_pool["name"], + "created_at": tourney_pool["created_at"], + "created_by": format_player_basic(pool_creator), "maps": { f"{mods!r}{slot}": format_map_basic(bmap) - for (mods, slot), bmap in pool.maps.items() + for (mods, slot), bmap in tourney_pool_maps.items() }, }, ) diff --git a/app/commands.py b/app/commands.py index 81c5c0662..abf6da95b 100644 --- a/app/commands.py +++ b/app/commands.py @@ -9,7 +9,6 @@ import time import traceback import uuid -from collections import defaultdict from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Mapping @@ -50,7 +49,6 @@ from app.objects.beatmap import RankedStatus from app.objects.beatmap import ensure_osu_file_is_available from app.objects.clan import Clan -from app.objects.match import MapPool from app.objects.match import Match from app.objects.match import MatchTeams from app.objects.match import 
MatchTeamTypes @@ -63,6 +61,8 @@ from app.repositories import map_requests as map_requests_repo from app.repositories import maps as maps_repo from app.repositories import players as players_repo +from app.repositories import tourney_pool_maps as tourney_pool_maps_repo +from app.repositories import tourney_pools as tourney_pools_repo from app.usecases.performance import ScoreParams from app.utils import seconds_readable @@ -1779,8 +1779,11 @@ async def mp_condition(ctx: Context, match: Match) -> str | None: @ensure_match async def mp_scrim(ctx: Context, match: Match) -> str | None: """Start a scrim in the current match.""" + if len(ctx.args) != 1: + return "Invalid syntax: !mp scrim " + r_match = regexes.BEST_OF.fullmatch(ctx.args[0]) - if len(ctx.args) != 1 or not r_match: + if not r_match: return "Invalid syntax: !mp scrim " best_of = int(r_match[1]) @@ -1895,15 +1898,18 @@ async def mp_loadpool(ctx: Context, match: Match) -> str | None: name = ctx.args[0] - pool = app.state.sessions.pools.get_by_name(name) - if not pool: + tourney_pool = await tourney_pools_repo.fetch_by_name(name) + if tourney_pool is None: return "Could not find a pool by that name!" - if match.pool is pool: - return f"{pool!r} already selected!" + if ( + match.tourney_pool is not None + and match.tourney_pool["id"] == tourney_pool["id"] + ): + return f"{tourney_pool['name']} already selected!" - match.pool = pool - return f"{pool!r} selected." + match.tourney_pool = tourney_pool + return f"{tourney_pool['name']} selected." @mp_commands.add(Privileges.UNRESTRICTED, aliases=["ulp"]) @@ -1916,10 +1922,10 @@ async def mp_unloadpool(ctx: Context, match: Match) -> str | None: if ctx.player is not match.host: return "Only available to the host." - if not match.pool: + if not match.tourney_pool: return "No mappool currently selected!" - match.pool = None + match.tourney_pool = None return "Mappool unloaded." @@ -1930,7 +1936,7 @@ async def mp_ban(ctx: Context, match: Match) -> str | None: if len(ctx.args) != 1: return "Invalid syntax: !mp ban " - if not match.pool: + if not match.tourney_pool: return "No pool currently selected!" mods_slot = ctx.args[0] @@ -1944,7 +1950,12 @@ async def mp_ban(ctx: Context, match: Match) -> str | None: mods = Mods.from_modstr(r_match[1]) slot = int(r_match[2]) - if (mods, slot) not in match.pool.maps: + map_pick = await tourney_pool_maps_repo.fetch_by_pool_and_pick( + pool_id=match.tourney_pool["id"], + mods=mods, + slot=slot, + ) + if map_pick is None: return f"Found no {mods_slot} pick in the pool." if (mods, slot) in match.bans: @@ -1961,7 +1972,7 @@ async def mp_unban(ctx: Context, match: Match) -> str | None: if len(ctx.args) != 1: return "Invalid syntax: !mp unban " - if not match.pool: + if not match.tourney_pool: return "No pool currently selected!" mods_slot = ctx.args[0] @@ -1975,7 +1986,12 @@ async def mp_unban(ctx: Context, match: Match) -> str | None: mods = Mods.from_modstr(r_match[1]) slot = int(r_match[2]) - if (mods, slot) not in match.pool.maps: + map_pick = await tourney_pool_maps_repo.fetch_by_pool_and_pick( + pool_id=match.tourney_pool["id"], + mods=mods, + slot=slot, + ) + if map_pick is None: return f"Found no {mods_slot} pick in the pool." if (mods, slot) not in match.bans: @@ -1992,7 +2008,7 @@ async def mp_pick(ctx: Context, match: Match) -> str | None: if len(ctx.args) != 1: return "Invalid syntax: !mp pick " - if not match.pool: + if not match.tourney_pool: return "No pool currently loaded!" 
mods_slot = ctx.args[0] @@ -2006,14 +2022,21 @@ async def mp_pick(ctx: Context, match: Match) -> str | None: mods = Mods.from_modstr(r_match[1]) slot = int(r_match[2]) - if (mods, slot) not in match.pool.maps: + map_pick = await tourney_pool_maps_repo.fetch_by_pool_and_pick( + pool_id=match.tourney_pool["id"], + mods=mods, + slot=slot, + ) + if map_pick is None: return f"Found no {mods_slot} pick in the pool." if (mods, slot) in match.bans: return f"{mods_slot} has been banned from being picked." - # update match beatmap to the picked map. - bmap = match.pool.maps[(mods, slot)] + bmap = await Beatmap.from_bid(map_pick["map_id"]) + if not bmap: + return f"Found no beatmap for {mods_slot} pick." + match.map_md5 = bmap.md5 match.map_id = bmap.id match.map_name = bmap.full_name @@ -2067,37 +2090,13 @@ async def pool_create(ctx: Context) -> str | None: name = ctx.args[0] - if app.state.sessions.pools.get_by_name(name): + existing_pool = await tourney_pools_repo.fetch_by_name(name) + if existing_pool is not None: return "Pool already exists by that name!" - # insert pool into db - await app.state.services.database.execute( - "INSERT INTO tourney_pools " - "(name, created_at, created_by) " - "VALUES (:name, NOW(), :user_id)", - {"name": name, "user_id": ctx.player.id}, - ) - - # add to cache (get from sql for id & time) - rec = await app.state.services.database.fetch_one( - "SELECT * FROM tourney_pools WHERE name = :name", - {"name": name}, - ) - assert rec is not None - row = dict(rec._mapping) - - pool_creator = await app.state.sessions.players.from_cache_or_sql( - id=row["created_by"], - ) - assert pool_creator is not None - - app.state.sessions.pools.append( - MapPool( - id=row["id"], - name=row["name"], - created_at=row["created_at"], - created_by=pool_creator, - ), + tourney_pool = await tourney_pools_repo.create( + name=name, + created_by=ctx.player.id, ) return f"{name} created." @@ -2111,23 +2110,12 @@ async def pool_delete(ctx: Context) -> str | None: name = ctx.args[0] - pool = app.state.sessions.pools.get_by_name(name) - if not pool: + existing_pool = await tourney_pools_repo.fetch_by_name(name) + if existing_pool is None: return "Could not find a pool by that name!" - # delete from db - await app.state.services.database.execute( - "DELETE FROM tourney_pools WHERE id = :pool_id", - {"pool_id": pool.id}, - ) - - await app.state.services.database.execute( - "DELETE FROM tourney_pool_maps WHERE pool_id = :pool_id", - {"pool_id": pool.id}, - ) - - # remove from cache - app.state.sessions.pools.remove(pool) + await tourney_pools_repo.delete_by_id(existing_pool["id"]) + await tourney_pool_maps_repo.delete_all_in_pool(pool_id=existing_pool["id"]) return f"{name} deleted." @@ -2157,27 +2145,29 @@ async def pool_add(ctx: Context) -> str | None: mods = Mods.from_modstr(r_match[1]) slot = int(r_match[2]) - pool = app.state.sessions.pools.get_by_name(name) - if not pool: + tourney_pool = await tourney_pools_repo.fetch_by_name(name) + if tourney_pool is None: return "Could not find a pool by that name!" - if (mods, slot) in pool.maps: - return f"{mods_slot} is already {pool.maps[(mods, slot)].embed}!" + tourney_pool_maps = await tourney_pool_maps_repo.fetch_many( + pool_id=tourney_pool["id"], + ) + for pool_map in tourney_pool_maps: + if mods == pool_map["mods"] and slot == pool_map["slot"]: + pool_beatmap = await Beatmap.from_bid(pool_map["map_id"]) + assert pool_beatmap is not None + return f"{mods_slot} is already {pool_beatmap.embed}!" 
- if bmap in pool.maps.values(): - return "Map is already in the pool!" + if pool_map["map_id"] == bmap.id: + return f"{bmap.embed} is already in the pool!" - # insert into db - await app.state.services.database.execute( - "INSERT INTO tourney_pool_maps " - "(map_id, pool_id, mods, slot) " - "VALUES (:map_id, :pool_id, :mods, :slot)", - {"map_id": bmap.id, "pool_id": pool.id, "mods": mods, "slot": slot}, + await tourney_pool_maps_repo.create( + map_id=bmap.id, + pool_id=tourney_pool["id"], + mods=mods, + slot=slot, ) - # add to cache - pool.maps[(mods, slot)] = bmap - return f"{bmap.embed} added to {name} as {mods_slot}." @@ -2199,38 +2189,39 @@ async def pool_remove(ctx: Context) -> str | None: mods = Mods.from_modstr(r_match[1]) slot = int(r_match[2]) - pool = app.state.sessions.pools.get_by_name(name) - if not pool: + tourney_pool = await tourney_pools_repo.fetch_by_name(name) + if tourney_pool is None: return "Could not find a pool by that name!" - if (mods, slot) not in pool.maps: + map_pick = await tourney_pool_maps_repo.fetch_by_pool_and_pick( + pool_id=tourney_pool["id"], + mods=mods, + slot=slot, + ) + if map_pick is None: return f"Found no {mods_slot} pick in the pool." - # delete from db - await app.state.services.database.execute( - "DELETE FROM tourney_pool_maps WHERE mods = :mods AND slot = :slot", - {"mods": mods, "slot": slot}, + await tourney_pool_maps_repo.delete_map_from_pool( + map_pick["pool_id"], + map_pick["map_id"], ) - # remove from cache - del pool.maps[(mods, slot)] - return f"{mods_slot} removed from {name}." @pool_commands.add(Privileges.TOURNEY_MANAGER, aliases=["l"], hidden=True) async def pool_list(ctx: Context) -> str | None: """List all existing mappools information.""" - pools = app.state.sessions.pools - if not pools: + tourney_pools = await tourney_pools_repo.fetch_many(page=None, page_size=None) + if not tourney_pools: return "There are currently no pools!" - l = [f"Mappools ({len(pools)})"] + l = [f"Mappools ({len(tourney_pools)})"] - for pool in pools: + for pool in tourney_pools: l.append( - f"[{pool.created_at:%Y-%m-%d}] {pool.id}. " - f"{pool.name}, by {pool.created_by}.", + f"[{pool['created_at']:%Y-%m-%d}] {pool['id']} " + f"{pool['name']}, by {pool['created_by']}.", ) return "\n".join(l) @@ -2244,20 +2235,26 @@ async def pool_info(ctx: Context) -> str | None: name = ctx.args[0] - pool = app.state.sessions.pools.get_by_name(name) - if not pool: + tourney_pool = await tourney_pools_repo.fetch_by_name(name) + if tourney_pool is None: return "Could not find a pool by that name!" - _time = pool.created_at.strftime("%H:%M:%S%p") - _date = pool.created_at.strftime("%Y-%m-%d") + _time = tourney_pool["created_at"].strftime("%H:%M:%S%p") + _date = tourney_pool["created_at"].strftime("%Y-%m-%d") datetime_fmt = f"Created at {_time} on {_date}" - l = [f"{pool.id}. {pool.name}, by {pool.created_by} | {datetime_fmt}."] + l = [ + f"{tourney_pool['id']}. 
{tourney_pool['name']}, by {tourney_pool['created_by']} | {datetime_fmt}.", + ] - for (mods, slot), bmap in sorted( - pool.maps.items(), - key=lambda x: (repr(x[0][0]), x[0][1]), + for tourney_map in sorted( + await tourney_pool_maps_repo.fetch_many(pool_id=tourney_pool["id"]), + key=lambda x: (repr(Mods(x["mods"])), x["slot"]), ): - l.append(f"{mods!r}{slot}: {bmap.embed}") + bmap = await Beatmap.from_bid(tourney_map["map_id"]) + if bmap is None: + log(f"Could not find beatmap {tourney_map['map_id']}.", Ansi.LRED) + continue + l.append(f"{Mods(tourney_map['mods'])!r}{tourney_map['slot']}: {bmap.embed}") return "\n".join(l) diff --git a/app/objects/collections.py b/app/objects/collections.py index 450d0de35..17ec7d4ea 100644 --- a/app/objects/collections.py +++ b/app/objects/collections.py @@ -16,27 +16,15 @@ from app.constants.privileges import Privileges from app.logging import Ansi from app.logging import log -from app.objects.achievement import Achievement from app.objects.channel import Channel from app.objects.clan import Clan -from app.objects.match import MapPool from app.objects.match import Match from app.objects.player import Player -from app.repositories import achievements as achievements_repo from app.repositories import channels as channels_repo from app.repositories import clans as clans_repo from app.repositories import players as players_repo from app.utils import make_safe_name -__all__ = ( - "Channels", - "Matches", - "Players", - "MapPools", - "Clans", - "initialize_ram_caches", -) - # TODO: decorator for these collections which automatically # adds debugging to their append/remove/insert/extend methods. @@ -302,85 +290,6 @@ def remove(self, player: Player) -> None: super().remove(player) -class MapPools(list[MapPool]): - """The currently active mappools on the server.""" - - def __iter__(self) -> Iterator[MapPool]: - return super().__iter__() - - def get( - self, - id: int | None = None, - name: str | None = None, - ) -> MapPool | None: - """Get a mappool by id, or name from cache.""" - for player in self: - if id is not None: - if player.id == id: - return player - elif name is not None: - if player.name == name: - return player - - return None - - def __contains__(self, o: object) -> bool: - """Check whether internal list contains `o`.""" - # Allow string to be passed to compare vs. name. 
- if isinstance(o, str): - return o in (pool.name for pool in self) - else: - return o in self - - def get_by_name(self, name: str) -> MapPool | None: - """Get a pool from the list by `name`.""" - for player in self: - if player.name == name: - return player - - return None - - def append(self, mappool: MapPool) -> None: - """Append `mappool` to the list.""" - super().append(mappool) - - if app.settings.DEBUG: - log(f"{mappool} added to mappools list.") - - def extend(self, mappools: Iterable[MapPool]) -> None: - """Extend the list with `mappools`.""" - super().extend(mappools) - - if app.settings.DEBUG: - log(f"{mappools} added to mappools list.") - - def remove(self, mappool: MapPool) -> None: - """Remove `mappool` from the list.""" - super().remove(mappool) - - if app.settings.DEBUG: - log(f"{mappool} removed from mappools list.") - - async def prepare(self, db_conn: databases.core.Connection) -> None: - """Fetch data from sql & return; preparing to run the server.""" - log("Fetching mappools from sql.", Ansi.LCYAN) - for row in await db_conn.fetch_all("SELECT * FROM tourney_pools"): - created_by = await app.state.sessions.players.from_cache_or_sql( - id=row["created_by"], - ) - - assert created_by is not None - - pool = MapPool( - id=row["id"], - name=row["name"], - created_at=row["created_at"], - created_by=created_by, - ) - await pool.maps_from_sql(db_conn) - self.append(pool) - - class Clans(list[Clan]): """The currently active clans on the server.""" @@ -457,7 +366,6 @@ async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: # fetch channels, clans and pools from db await app.state.sessions.channels.prepare(db_conn) await app.state.sessions.clans.prepare(db_conn) - await app.state.sessions.pools.prepare(db_conn) bot = await players_repo.fetch_one(id=1) if bot is None: diff --git a/app/objects/match.py b/app/objects/match.py index 79e503cef..47e41060f 100644 --- a/app/objects/match.py +++ b/app/objects/match.py @@ -21,6 +21,7 @@ from app.logging import Ansi from app.logging import log from app.objects.beatmap import Beatmap +from app.repositories.tourney_pools import TourneyPool from app.utils import escape_enum from app.utils import pymysql_encode @@ -30,16 +31,6 @@ from app.objects.channel import Channel from app.objects.player import Player -__all__ = ( - "SlotStatus", - "MatchTeams", - #'MatchTypes', - "MatchWinConditions", - "MatchTeamTypes", - "MapPool", - "Slot", - "Match", -) MAX_MATCH_NAME_LENGTH = 50 @@ -96,54 +87,6 @@ class MatchTeamTypes(IntEnum): tag_team_vs = 3 -class MapPool: - def __init__( - self, - id: int, - name: str, - created_at: datetime, - created_by: Player, - ) -> None: - self.id = id - self.name = name - self.created_at = created_at - self.created_by = created_by - - self.maps: dict[ - tuple[Mods, int], - Beatmap, - ] = {} - - def __repr__(self) -> str: - return f"<{self.name}>" - - async def maps_from_sql(self, db_conn: databases.core.Connection) -> None: - """Retrieve all maps from sql to populate `self.maps`.""" - for row in await db_conn.fetch_all( - "SELECT map_id, mods, slot FROM tourney_pool_maps WHERE pool_id = :pool_id", - {"pool_id": self.id}, - ): - map_id = row["map_id"] - bmap = await Beatmap.from_bid(map_id) - - if not bmap: - # map not found? remove it from the - # pool and log this incident to console. - # NOTE: it's intentional that this removes - # it from not only this pool, but all pools. - # TODO: perhaps discord webhook? 
- log(f"Removing {map_id} from pool {self.name} (not found).", Ansi.LRED) - - await db_conn.execute( - "DELETE FROM tourney_pool_maps WHERE map_id = :map_id", - {"map_id": map_id}, - ) - continue - - key: tuple[Mods, int] = (Mods(row["mods"]), row["slot"]) - self.maps[key] = bmap - - class Slot: """An individual player slot in an osu! multiplayer match.""" @@ -247,7 +190,7 @@ def __init__( self.starting: StartingTimers | None = None self.seed = seed # used for mania random mod - self.pool: MapPool | None = None + self.tourney_pool: TourneyPool | None = None # scrimmage stuff self.is_scrimming = False diff --git a/app/repositories/tourney_pool_maps.py b/app/repositories/tourney_pool_maps.py new file mode 100644 index 000000000..0e473410b --- /dev/null +++ b/app/repositories/tourney_pool_maps.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import textwrap +from typing import Any +from typing import TypedDict +from typing import cast + +import app.state.services + +# +---------+---------+------+-----+---------+-------+ +# | Field | Type | Null | Key | Default | Extra | +# +---------+---------+------+-----+---------+-------+ +# | map_id | int | NO | PRI | NULL | | +# | pool_id | int | NO | PRI | NULL | | +# | mods | int | NO | | NULL | | +# | slot | tinyint | NO | | NULL | | +# +---------+---------+------+-----+---------+-------+ + + +class TourneyPoolMap(TypedDict): + map_id: int + pool_id: int + mods: int + slot: int + + +READ_PARAMS = textwrap.dedent( + """\ + map_id, pool_id, mods, slot + """, +) + + +async def create(map_id: int, pool_id: int, mods: int, slot: int) -> TourneyPoolMap: + """Create a new map pool entry in the database.""" + query = f"""\ + INSERT INTO tourney_pool_maps (map_id, pool_id, mods, slot) + VALUES (:map_id, :pool_id, :mods, :slot) + """ + params: dict[str, Any] = { + "map_id": map_id, + "pool_id": pool_id, + "mods": mods, + "slot": slot, + } + await app.state.services.database.execute(query, params) + + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pool_maps + WHERE map_id = :map_id + AND pool_id = :pool_id + AND mods = :mods + AND slot = :slot + """ + params = { + "map_id": map_id, + "pool_id": pool_id, + "mods": mods, + "slot": slot, + } + tourney_pool_map = await app.state.services.database.fetch_one(query, params) + + assert tourney_pool_map is not None + return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + + +async def fetch_many( + pool_id: int | None = None, + mods: int | None = None, + slot: int | None = None, + page: int | None = 1, + page_size: int | None = 50, +) -> list[TourneyPoolMap]: + """Fetch a list of map pool entries from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pool_maps + WHERE pool_id = COALESCE(:pool_id, pool_id) + AND mods = COALESCE(:mods, mods) + AND slot = COALESCE(:slot, slot) + """ + params: dict[str, Any] = { + "pool_id": pool_id, + "mods": mods, + "slot": slot, + } + if page and page_size: + query += """\ + LIMIT :limit + OFFSET :offset + """ + params["limit"] = page_size + params["offset"] = (page - 1) * page_size + tourney_pool_maps = await app.state.services.database.fetch_all(query, params) + return cast( + list[TourneyPoolMap], + [dict(tourney_pool_map._mapping) for tourney_pool_map in tourney_pool_maps], + ) + + +async def fetch_by_pool_and_pick( + pool_id: int, + mods: int, + slot: int, +) -> TourneyPoolMap | None: + """Fetch a map pool entry by pool and pick from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pool_maps + WHERE pool_id = :pool_id 
+ AND mods = :mods + AND slot = :slot + """ + params: dict[str, Any] = { + "pool_id": pool_id, + "mods": mods, + "slot": slot, + } + tourney_pool_map = await app.state.services.database.fetch_one(query, params) + if tourney_pool_map is None: + return None + return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + + +async def delete_map_from_pool(pool_id: int, map_id: int) -> TourneyPoolMap | None: + """Delete a map pool entry from a given tourney pool from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pool_maps + WHERE pool_id = :pool_id + AND map_id = :map_id + """ + params: dict[str, Any] = { + "pool_id": pool_id, + "map_id": map_id, + } + tourney_pool_map = await app.state.services.database.fetch_one(query, params) + if tourney_pool_map is None: + return None + + query = f"""\ + DELETE FROM tourney_pool_maps + WHERE pool_id = :pool_id + AND map_id = :map_id + """ + params = { + "pool_id": pool_id, + "map_id": map_id, + } + await app.state.services.database.execute(query, params) + return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + + +async def delete_all_in_pool(pool_id: int) -> list[TourneyPoolMap]: + """Delete all map pool entries from a given tourney pool from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pool_maps + WHERE pool_id = :pool_id + """ + params: dict[str, Any] = { + "pool_id": pool_id, + } + tourney_pool_maps = await app.state.services.database.fetch_all(query, params) + if not tourney_pool_maps: + return [] + + query = f"""\ + DELETE FROM tourney_pool_maps + WHERE pool_id = :pool_id + """ + params = { + "pool_id": pool_id, + } + await app.state.services.database.execute(query, params) + return cast( + list[TourneyPoolMap], + [dict(tourney_pool_map._mapping) for tourney_pool_map in tourney_pool_maps], + ) diff --git a/app/repositories/tourney_pools.py b/app/repositories/tourney_pools.py new file mode 100644 index 000000000..dade1a2ed --- /dev/null +++ b/app/repositories/tourney_pools.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import textwrap +from datetime import datetime +from typing import Any +from typing import TypedDict +from typing import cast + +import app.state.services + +# +------------+-------------+------+-----+---------+----------------+ +# | Field | Type | Null | Key | Default | Extra | +# +------------+-------------+------+-----+---------+----------------+ +# | id | int | NO | PRI | NULL | auto_increment | +# | name | varchar(16) | NO | | NULL | | +# | created_at | datetime | NO | | NULL | | +# | created_by | int | NO | MUL | NULL | | +# +------------+-------------+------+-----+---------+----------------+ + + +class TourneyPool(TypedDict): + id: int + name: str + created_at: datetime + created_by: int + + +READ_PARAMS = textwrap.dedent( + """\ + id, name, created_at, created_by + """, +) + + +async def create(name: str, created_by: int) -> TourneyPool: + """Create a new tourney pool entry in the database.""" + query = f"""\ + INSERT INTO tourney_pools (name, created_at, created_by) + VALUES (:name, NOW(), :user_id) + """ + params: dict[str, Any] = { + "name": name, + "user_id": created_by, + } + rec_id = await app.state.services.database.execute(query, params) + + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pools + WHERE id = :id + """ + params = { + "id": rec_id, + } + tourney_pool = await app.state.services.database.fetch_one(query, params) + + assert tourney_pool is not None + return cast(TourneyPool, dict(tourney_pool._mapping)) + + +async def fetch_many( + id: 
int | None = None, + created_by: int | None = None, + page: int | None = 1, + page_size: int | None = 50, +) -> list[TourneyPool]: + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pools + WHERE id = COALESCE(:id, id) + AND created_by = COALESCE(:created_by, created_by) + """ + params: dict[str, Any] = { + "id": id, + "created_by": created_by, + } + if page and page_size: + query += """\ + LIMIT :limit + OFFSET :offset + """ + params["limit"] = page_size + params["offset"] = (page - 1) * page_size + tourney_pools = await app.state.services.database.fetch_all(query, params) + return [ + cast(TourneyPool, dict(tourney_pool._mapping)) for tourney_pool in tourney_pools + ] + + +async def fetch_by_name(name: str) -> TourneyPool | None: + """Fetch a tourney pool by name from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pools + WHERE name = :name + """ + params: dict[str, Any] = { + "name": name, + } + tourney_pool = await app.state.services.database.fetch_one(query, params) + return ( + cast(TourneyPool, dict(tourney_pool._mapping)) + if tourney_pool is not None + else None + ) + + +async def fetch_by_id(id: int) -> TourneyPool | None: + """Fetch a tourney pool by id from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pools + WHERE id = :id + """ + params: dict[str, Any] = { + "id": id, + } + tourney_pool = await app.state.services.database.fetch_one(query, params) + return ( + cast(TourneyPool, dict(tourney_pool._mapping)) + if tourney_pool is not None + else None + ) + + +async def delete_by_id(id: int) -> TourneyPool | None: + """Delete a tourney pool by id from the database.""" + query = f"""\ + SELECT {READ_PARAMS} + FROM tourney_pools + WHERE id = :id + """ + params: dict[str, Any] = { + "id": id, + } + tourney_pool = await app.state.services.database.fetch_one(query, params) + if tourney_pool is None: + return None + + query = f"""\ + DELETE FROM tourney_pools + WHERE id = :id + """ + params = { + "id": id, + } + await app.state.services.database.execute(query, params) + return cast(TourneyPool, dict(tourney_pool._mapping)) diff --git a/app/state/sessions.py b/app/state/sessions.py index 42934a062..4fe183b82 100644 --- a/app/state/sessions.py +++ b/app/state/sessions.py @@ -8,17 +8,14 @@ from app.logging import log from app.objects.collections import Channels from app.objects.collections import Clans -from app.objects.collections import MapPools from app.objects.collections import Matches from app.objects.collections import Players if TYPE_CHECKING: - from app.objects.achievement import Achievement from app.objects.player import Player players = Players() channels = Channels() -pools = MapPools() clans = Clans() matches = Matches() From 250cbfc660def1508c210f4d4eb1b67c8a129a7e Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 16:35:15 -0500 Subject: [PATCH 11/48] Migrate beatmap objects off kwargs (& fix bugs found from migration) (#615) * migrate beatmap obj off kwargs * fix type errs found from migration * fix wrong type in repo map obj --- app/objects/beatmap.py | 82 ++++++++++++++++++++++++---------------- app/repositories/maps.py | 9 +++-- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/app/objects/beatmap.py b/app/objects/beatmap.py index d475648d6..9e72cf822 100644 --- a/app/objects/beatmap.py +++ b/app/objects/beatmap.py @@ -277,38 +277,56 @@ class Beatmap: # XXX: This is set when a map's status is manually changed. 
""" - def __init__(self, map_set: BeatmapSet, **kwargs: Any) -> None: + def __init__( + self, + map_set: BeatmapSet, + md5: str = "", + id: int = 0, + set_id: int = 0, + artist: str = "", + title: str = "", + version: str = "", + creator: str = "", + last_update: datetime = DEFAULT_LAST_UPDATE, + total_length: int = 0, + max_combo: int = 0, + status: RankedStatus = RankedStatus.Pending, + frozen: bool = False, + plays: int = 0, + passes: int = 0, + mode: GameMode = GameMode.VANILLA_OSU, + bpm: float = 0.0, + cs: float = 0.0, + od: float = 0.0, + ar: float = 0.0, + hp: float = 0.0, + diff: float = 0.0, + filename: str = "", + ) -> None: self.set = map_set - self.md5 = kwargs.get("md5", "") - self.id = kwargs.get("id", 0) - self.set_id = kwargs.get("set_id", 0) - - self.artist = kwargs.get("artist", "") - self.title = kwargs.get("title", "") - self.version = kwargs.get("version", "") # diff name - self.creator = kwargs.get("creator", "") - - self.last_update = kwargs.get("last_update", DEFAULT_LAST_UPDATE) - self.total_length = kwargs.get("total_length", 0) - self.max_combo = kwargs.get("max_combo", 0) - - self.status = RankedStatus(kwargs.get("status", 0)) - self.frozen = kwargs.get("frozen", False) == 1 - - self.plays = kwargs.get("plays", 0) - self.passes = kwargs.get("passes", 0) - self.mode = GameMode(kwargs.get("mode", 0)) - self.bpm = kwargs.get("bpm", 0.0) - - self.cs = kwargs.get("cs", 0.0) - self.od = kwargs.get("od", 0.0) - self.ar = kwargs.get("ar", 0.0) - self.hp = kwargs.get("hp", 0.0) - - self.diff = kwargs.get("diff", 0.0) - - self.filename = kwargs.get("filename", "") + self.md5 = md5 + self.id = id + self.set_id = set_id + self.artist = artist + self.title = title + self.version = version + self.creator = creator + self.last_update = last_update + self.total_length = total_length + self.max_combo = max_combo + self.status = status + self.frozen = frozen + self.plays = plays + self.passes = passes + self.mode = mode + self.bpm = bpm + self.cs = cs + self.od = od + self.ar = ar + self.hp = hp + self.diff = diff + self.filename = filename def __repr__(self) -> str: return self.full_name @@ -840,11 +858,11 @@ async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None: last_update=row["last_update"], total_length=row["total_length"], max_combo=row["max_combo"], - status=row["status"], + status=RankedStatus(row["status"]), frozen=row["frozen"], plays=row["plays"], passes=row["passes"], - mode=row["mode"], + mode=GameMode(row["mode"]), bpm=row["bpm"], cs=row["cs"], od=row["od"], diff --git a/app/repositories/maps.py b/app/repositories/maps.py index 6066c852d..d64a1e49f 100644 --- a/app/repositories/maps.py +++ b/app/repositories/maps.py @@ -1,6 +1,7 @@ from __future__ import annotations import textwrap +from datetime import datetime from typing import Any from typing import TypedDict from typing import cast @@ -57,7 +58,7 @@ class Map(TypedDict): version: str creator: str filename: str - last_update: str + last_update: datetime total_length: int max_combo: int frozen: bool @@ -82,7 +83,7 @@ class MapUpdateFields(TypedDict, total=False): version: str creator: str filename: str - last_update: str + last_update: datetime total_length: int max_combo: int frozen: bool @@ -108,7 +109,7 @@ async def create( version: str, creator: str, filename: str, - last_update: str, + last_update: datetime, total_length: int, max_combo: int, frozen: bool, @@ -298,7 +299,7 @@ async def update( version: str | _UnsetSentinel = UNSET, creator: str | _UnsetSentinel = UNSET, filename: str | 
_UnsetSentinel = UNSET, - last_update: str | _UnsetSentinel = UNSET, + last_update: datetime | _UnsetSentinel = UNSET, total_length: int | _UnsetSentinel = UNSET, max_combo: int | _UnsetSentinel = UNSET, frozen: bool | _UnsetSentinel = UNSET, From 0483d4494c5111dc8de6a27952420f50b98c602c Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 16:42:26 -0500 Subject: [PATCH 12/48] Remove unnecessary features to decrease maintenance/complexity overhead. + Cleanup & more repo usage. (#612) * Start to remove unnecessary complex code & unnecessary features * remove usage of `__all__`, use players repo in players obj, general cleanups * remove more unnecessary code * app host/port required -- remove deprecated vals * unused import * remove deprecated env vars from docker compoes files * re-add warn on lack of internet connectivity * use better word * move internet check into startup dialog --- app/api/domains/osu.py | 9 +- app/api/init_api.py | 10 +- app/bg_loops.py | 2 - app/commands.py | 4 +- app/constants/clientflags.py | 2 - app/constants/gamemodes.py | 2 - app/constants/mods.py | 4 - app/constants/privileges.py | 2 - app/constants/regexes.py | 3 - app/discord.py | 14 --- app/objects/achievement.py | 2 - app/objects/channel.py | 2 - app/objects/clan.py | 2 - app/objects/player.py | 68 ++++++------- app/settings.py | 18 +--- app/utils.py | 188 +++++------------------------------ docker-compose.test.yml | 3 - docker-compose.yml | 3 - ext/Caddyfile | 2 - ext/nginx.conf.example | 2 - main.py | 105 +------------------ 21 files changed, 80 insertions(+), 367 deletions(-) diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index ac35a609a..b533e956b 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -1593,9 +1593,16 @@ async def get_screenshot( status_code=status.HTTP_404_NOT_FOUND, ) + if extension in ("jpg", "jpeg"): + media_type = "image/jpeg" + elif extension == "png": + media_type = "image/png" + else: + media_type = None + return FileResponse( path=screenshot_path, - media_type=app.utils.get_media_type(extension), + media_type=media_type, ) diff --git a/app/api/init_api.py b/app/api/init_api.py index d94244623..5f856364c 100644 --- a/app/api/init_api.py +++ b/app/api/init_api.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio +import io import os import pprint +import sys from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any @@ -69,11 +71,13 @@ def openapi(self) -> dict[str, Any]: @asynccontextmanager async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[Never]: - app.utils.setup_runtime_environment() - app.utils.ensure_supported_platform() - app.utils.ensure_directory_structure() + if isinstance(sys.stdout, io.TextIOWrapper): + sys.stdout.reconfigure(encoding="utf-8") + + app.utils.ensure_persistent_volumes_are_available() app.state.loop = asyncio.get_running_loop() + if app.utils.is_running_as_admin(): log( "Running the server with root privileges is not recommended.", diff --git a/app/bg_loops.py b/app/bg_loops.py index 5a8a0dbe1..02f4a8a22 100644 --- a/app/bg_loops.py +++ b/app/bg_loops.py @@ -10,8 +10,6 @@ from app.logging import Ansi from app.logging import log -__all__ = ("initialize_housekeeping_tasks",) - OSU_CLIENT_MIN_PING_INTERVAL = 300000 // 1000 # defined by osu! 
diff --git a/app/commands.py b/app/commands.py index abf6da95b..262ae3765 100644 --- a/app/commands.py +++ b/app/commands.py @@ -15,6 +15,7 @@ from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime +from datetime import timedelta from functools import wraps from pathlib import Path from time import perf_counter_ns as clock_ns @@ -64,7 +65,6 @@ from app.repositories import tourney_pool_maps as tourney_pool_maps_repo from app.repositories import tourney_pools as tourney_pools_repo from app.usecases.performance import ScoreParams -from app.utils import seconds_readable if TYPE_CHECKING: from app.objects.channel import Channel @@ -1258,7 +1258,7 @@ async def server(ctx: Context) -> str | None: return "\n".join( ( - f"{build_str} | uptime: {seconds_readable(uptime)}", + f"{build_str} | uptime: {timedelta(seconds=uptime)}", f"cpu: {cpu_info_str}", f"ram: {ram_info}", f"search mirror: {mirror_search_url} | download mirror: {mirror_download_url}", diff --git a/app/constants/clientflags.py b/app/constants/clientflags.py index 015c3df80..ec23b0f38 100644 --- a/app/constants/clientflags.py +++ b/app/constants/clientflags.py @@ -6,8 +6,6 @@ from app.utils import escape_enum from app.utils import pymysql_encode -__all__ = ("ClientFlags",) - @unique @pymysql_encode(escape_enum) diff --git a/app/constants/gamemodes.py b/app/constants/gamemodes.py index 0469a1083..768d40ced 100644 --- a/app/constants/gamemodes.py +++ b/app/constants/gamemodes.py @@ -8,8 +8,6 @@ from app.utils import escape_enum from app.utils import pymysql_encode -__all__ = ("GAMEMODE_REPR_LIST", "GameMode") - GAMEMODE_REPR_LIST = ( "vn!std", "vn!taiko", diff --git a/app/constants/mods.py b/app/constants/mods.py index 19907cb05..0392b5a2a 100644 --- a/app/constants/mods.py +++ b/app/constants/mods.py @@ -7,10 +7,6 @@ from app.utils import escape_enum from app.utils import pymysql_encode -__all__ = ("Mods",) - -# NOTE: the order of some of these = stupid - @unique @pymysql_encode(escape_enum) diff --git a/app/constants/privileges.py b/app/constants/privileges.py index 2b5d930ca..590c646c6 100644 --- a/app/constants/privileges.py +++ b/app/constants/privileges.py @@ -7,8 +7,6 @@ from app.utils import escape_enum from app.utils import pymysql_encode -__all__ = ("Privileges", "ClientPrivileges", "ClanPrivileges") - @unique @pymysql_encode(escape_enum) diff --git a/app/constants/regexes.py b/app/constants/regexes.py index 12c26ac1b..2b9368a7b 100644 --- a/app/constants/regexes.py +++ b/app/constants/regexes.py @@ -2,9 +2,6 @@ import re -__all__ = ("OSU_VERSION", "USERNAME", "EMAIL", "BEST_OF") - - OSU_VERSION = re.compile( r"^b(?P\d{8})(?:\.(?P\d))?" r"(?Pbeta|cuttingedge|dev|tourney)?$", diff --git a/app/discord.py b/app/discord.py index 496702e2d..87fd13e36 100644 --- a/app/discord.py +++ b/app/discord.py @@ -10,20 +10,6 @@ from app.state import services -# NOTE: this module currently only implements discord webhooks - -__all__ = ( - "Footer", - "Image", - "Thumbnail", - "Video", - "Provider", - "Author", - "Field", - "Embed", - "Webhook", -) - class Footer: def __init__(self, text: str, **kwargs: Any) -> None: diff --git a/app/objects/achievement.py b/app/objects/achievement.py index 41b0f99b6..df767232c 100644 --- a/app/objects/achievement.py +++ b/app/objects/achievement.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from app.objects.score import Score -__all__ = ("Achievement",) - class Achievement: """A class to represent a single osu! 
achievement.""" diff --git a/app/objects/channel.py b/app/objects/channel.py index d6fe22d62..41dbfa8cc 100644 --- a/app/objects/channel.py +++ b/app/objects/channel.py @@ -10,8 +10,6 @@ if TYPE_CHECKING: from app.objects.player import Player -__all__ = ("Channel",) - class Channel: """An osu! chat channel. diff --git a/app/objects/clan.py b/app/objects/clan.py index 0a69e591e..9ba79f539 100644 --- a/app/objects/clan.py +++ b/app/objects/clan.py @@ -11,8 +11,6 @@ if TYPE_CHECKING: from app.objects.player import Player -__all__ = ("Clan",) - class Clan: """A class to represent a single bancho.py clan.""" diff --git a/app/objects/player.py b/app/objects/player.py index 9c86e5a1f..5a5c4ca59 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -10,7 +10,6 @@ from enum import unique from functools import cached_property from typing import TYPE_CHECKING -from typing import Any from typing import TypedDict from typing import cast @@ -36,6 +35,7 @@ from app.objects.score import Grade from app.objects.score import Score from app.repositories import logs as logs_repo +from app.repositories import players as players_repo from app.repositories import stats as stats_repo from app.state.services import Geolocation from app.utils import escape_enum @@ -238,7 +238,6 @@ def __init__( self.id = id self.name = name - self.safe_name = self.make_safe(self.name) self.priv = priv self.pw_bcrypt = pw_bcrypt self.token = token @@ -282,8 +281,6 @@ def _noop_enqueue(data: bytes) -> None: self.pres_filter = PresenceFilter.Nil - # XXX: below is mostly implementation-specific & internal stuff - # store most recent score for each gamemode. self.recent_scores: dict[GameMode, Score | None] = { mode: None for mode in GameMode @@ -297,6 +294,10 @@ def _noop_enqueue(data: bytes) -> None: def __repr__(self) -> str: return f"<{self.name} ({self.id})>" + @property + def safe_name(self) -> str: + return make_safe_name(self.name) + @property def is_online(self) -> bool: return bool(self.token != "") @@ -384,11 +385,6 @@ def generate_token() -> str: """Generate a random uuid as a token.""" return str(uuid.uuid4()) - @staticmethod - def make_safe(name: str) -> str: - """Return a name safe for usage in sql.""" - return make_safe_name(name) - def logout(self) -> None: """Log `self` out of the server.""" # invalidate the user's token. 
@@ -421,28 +417,28 @@ def logout(self) -> None: async def update_privs(self, new: Privileges) -> None: """Update `self`'s privileges to `new`.""" + self.priv = new + if "bancho_priv" in vars(self): + del self.bancho_priv # wipe cached_property - await app.state.services.database.execute( - "UPDATE users SET priv = :priv WHERE id = :user_id", - {"priv": self.priv, "user_id": self.id}, + await players_repo.update( + id=self.id, + priv=self.priv, ) - if "bancho_priv" in self.__dict__: - del self.bancho_priv # wipe cached_property - async def add_privs(self, bits: Privileges) -> None: """Update `self`'s privileges, adding `bits`.""" + self.priv |= bits + if "bancho_priv" in vars(self): + del self.bancho_priv # wipe cached_property - await app.state.services.database.execute( - "UPDATE users SET priv = :priv WHERE id = :user_id", - {"priv": self.priv, "user_id": self.id}, + await players_repo.update( + id=self.id, + priv=self.priv, ) - if "bancho_priv" in self.__dict__: - del self.bancho_priv # wipe cached_property - if self.is_online: # if they're online, send a packet # to update their client-side privileges @@ -450,16 +446,16 @@ async def add_privs(self, bits: Privileges) -> None: async def remove_privs(self, bits: Privileges) -> None: """Update `self`'s privileges, removing `bits`.""" + self.priv &= ~bits + if "bancho_priv" in vars(self): + del self.bancho_priv # wipe cached_property - await app.state.services.database.execute( - "UPDATE users SET priv = :priv WHERE id = :user_id", - {"priv": self.priv, "user_id": self.id}, + await players_repo.update( + id=self.id, + priv=self.priv, ) - if "bancho_priv" in self.__dict__: - del self.bancho_priv # wipe cached_property - if self.is_online: # if they're online, send a packet # to update their client-side privileges @@ -542,9 +538,9 @@ async def silence(self, admin: Player, duration: float, reason: str) -> None: """Silence `self` for `duration` seconds, and log to sql.""" self.silence_end = int(time.time() + duration) - await app.state.services.database.execute( - "UPDATE users SET silence_end = :silence_end WHERE id = :user_id", - {"silence_end": self.silence_end, "user_id": self.id}, + await players_repo.update( + id=self.id, + silence_end=self.silence_end, ) await logs_repo.create( @@ -570,9 +566,9 @@ async def unsilence(self, admin: Player, reason: str) -> None: """Unsilence `self`, and log to sql.""" self.silence_end = int(time.time()) - await app.state.services.database.execute( - "UPDATE users SET silence_end = :silence_end WHERE id = :user_id", - {"silence_end": self.silence_end, "user_id": self.id}, + await players_repo.update( + id=self.id, + silence_end=self.silence_end, ) await logs_repo.create( @@ -1006,9 +1002,9 @@ async def stats_from_sql_full(self, db_conn: databases.core.Connection) -> None: def update_latest_activity_soon(self) -> None: """Update the player's latest activity in the database.""" - task = app.state.services.database.execute( - "UPDATE users SET latest_activity = UNIX_TIMESTAMP() WHERE id = :user_id", - {"user_id": self.id}, + task = players_repo.update( + id=self.id, + latest_activity=int(time.time()), ) app.state.loop.create_task(task) diff --git a/app/settings.py b/app/settings.py index 457cd74d7..0980c7842 100644 --- a/app/settings.py +++ b/app/settings.py @@ -2,30 +2,16 @@ import os import tomllib -from datetime import date from dotenv import load_dotenv from app.settings_utils import read_bool from app.settings_utils import read_list -from app.settings_utils import support_deprecated_vars load_dotenv() 
-APP_HOST = support_deprecated_vars( - new_name="APP_HOST", - deprecated_name="SERVER_ADDR", - until=date(2024, 1, 1), -) -APP_PORT = None -_app_port = support_deprecated_vars( - new_name="APP_PORT", - deprecated_name="SERVER_PORT", - until=date(2024, 1, 1), - allow_empty_string=True, -) -if _app_port: - APP_PORT = int(_app_port) +APP_HOST = os.environ["APP_HOST"] +APP_PORT = int(os.environ["APP_PORT"]) DB_HOST = os.environ["DB_HOST"] DB_PORT = int(os.environ["DB_PORT"]) diff --git a/app/utils.py b/app/utils.py index 875421956..f47726bee 100644 --- a/app/utils.py +++ b/app/utils.py @@ -2,13 +2,9 @@ import ctypes import inspect -import io -import ipaddress import os -import shutil import socket import sys -import types from collections.abc import Callable from pathlib import Path from typing import Any @@ -16,13 +12,11 @@ from typing import TypeVar import httpx -import orjson import pymysql import app.settings from app.logging import Ansi from app.logging import log -from app.logging import printc T = TypeVar("T") @@ -93,8 +87,9 @@ def download_achievement_images(achievements_path: Path) -> None: log("Failed to download achievement images.", Ansi.LRED) achievements_path.rmdir() - # allow passthrough (don't hard crash) as the server will - # _mostly_ work in this state. + # allow passthrough (don't hard crash). + # the server will *mostly* work in this state. + pass def download_default_avatar(default_avatar_path: Path) -> None: @@ -109,101 +104,30 @@ def download_default_avatar(default_avatar_path: Path) -> None: default_avatar_path.write_bytes(resp.content) -def seconds_readable(seconds: int) -> str: - """Turn seconds as an int into 'DD:HH:MM:SS'.""" - r: list[str] = [] - - days, seconds = divmod(seconds, 60 * 60 * 24) - if days: - r.append(f"{days:02d}") - - hours, seconds = divmod(seconds, 60 * 60) - if hours: - r.append(f"{hours:02d}") - - minutes, seconds = divmod(seconds, 60) - r.append(f"{minutes:02d}") - - r.append(f"{seconds % 60:02d}") - return ":".join(r) - - -def check_connection(timeout: float = 1.0) -> bool: +def has_internet_connectivity(timeout: float = 1.0) -> bool: """Check for an active internet connection.""" - # attempt to connect to common dns servers - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.settimeout(timeout) - for addr in ( - "1.1.1.1", - "1.0.0.1", # cloudflare - "8.8.8.8", - "8.8.4.4", - ): # google + COMMON_DNS_SERVERS = ( + # Cloudflare + "1.1.1.1", + "1.0.0.1", + # Google + "8.8.8.8", + "8.8.4.4", + ) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client: + client.settimeout(timeout) + for host in COMMON_DNS_SERVERS: try: - sock.connect((addr, 53)) - return True + client.connect((host, 53)) except OSError: continue + else: + return True # all connections failed return False -def processes_listening_on_unix_socket(socket_path: str) -> int: - """Return the number of processes currently listening on this socket.""" - with open("/proc/net/unix") as f: # TODO: does this require root privs? 
- unix_socket_data = f.read().splitlines(keepends=False) - - process_count = 0 - - for line in unix_socket_data[1:]: - # 0000000045fe59d0: 00000002 00000000 00010000 0005 01 17665 /tmp/bancho.sock - tokens = line.split() - - # unused params - # ( - # kernel_table_slot_num, - # ref_count, - # protocol, - # flags, - # sock_type, - # sock_state, - # inode, - # ) = tokens[0:7] - - # path may or may not be set - if len(tokens) == 8 and tokens[7] == socket_path: - process_count += 1 - - return process_count - - -def running_via_asgi_webserver() -> bool: - return any(map(sys.argv[0].endswith, ("hypercorn", "uvicorn"))) - - -def _install_synchronous_excepthook() -> None: - """Install a thin wrapper for sys.excepthook to catch bancho-related stuff.""" - real_excepthook = sys.excepthook # backup - - def _excepthook( - type_: type[BaseException], - value: BaseException, - traceback: types.TracebackType | None, - ) -> None: - if type_ is KeyboardInterrupt: - print("\33[2K\r", end="Aborted startup.") - return - - printc( - f"bancho.py v{app.settings.VERSION} ran into an issue before starting up :(", - Ansi.RED, - ) - real_excepthook(type_, value, traceback) - - sys.excepthook = _excepthook - - class FrameInfo(TypedDict): function: str filename: str @@ -236,21 +160,6 @@ def get_appropriate_stacktrace() -> list[FrameInfo]: ] -def is_valid_inet_address(address: str) -> bool: - """Check whether address is a valid ipv(4/6) address.""" - try: - ipaddress.ip_address(address) - except ValueError: - return False - else: - return True - - -def is_valid_unix_address(address: str) -> bool: - """Check whether address is a valid unix address.""" - return address.endswith(".sock") - - def pymysql_encode( conv: Callable[[Any, dict[object, object] | None], str], ) -> Callable[[type[T]], type[T]]: @@ -270,56 +179,25 @@ def escape_enum( return str(int(val)) -def ensure_supported_platform() -> None: - """Ensure we're running on an appropriate platform for bancho.py.""" - if sys.version_info < (3, 11): - log( - "bancho.py uses many modern python features, " - "and the minimum python version is 3.11.", - Ansi.LRED, - ) - raise SystemExit(1) - - -def ensure_directory_structure() -> None: - """Ensure the .data directory and git submodules are ready.""" - # create /.data and its subdirectories. +def ensure_persistent_volumes_are_available() -> None: + # create /.data directory DATA_PATH.mkdir(exist_ok=True) + # create /.data/... subdirectories for sub_dir in ("avatars", "logs", "osu", "osr", "ss"): subdir = DATA_PATH / sub_dir subdir.mkdir(exist_ok=True) + # download achievement images from osu! if not ACHIEVEMENTS_ASSETS_PATH.exists(): ACHIEVEMENTS_ASSETS_PATH.mkdir(parents=True) download_achievement_images(ACHIEVEMENTS_ASSETS_PATH) + # download a default avatar image for new users if not DEFAULT_AVATAR_PATH.exists(): download_default_avatar(DEFAULT_AVATAR_PATH) -def setup_runtime_environment() -> None: - """Configure the server's runtime environment.""" - # install a hook to catch exceptions outside the event loop, - # which will handle various situations where the error details - # can be cleared up for the developer; for example it will explain - # that the config has been updated when an unknown attribute is - # accessed, so the developer knows what to do immediately. 
- _install_synchronous_excepthook() - - # we print utf-8 content quite often, so configure sys.stdout - if isinstance(sys.stdout, io.TextIOWrapper): - sys.stdout.reconfigure(encoding="utf-8") - - -def _install_debugging_hooks() -> None: - """Change internals to help with debugging & active development.""" - if DEBUG_HOOKS_PATH.exists(): - from _testing import runtime # type: ignore - - runtime.setup() - - def is_running_as_admin() -> bool: try: return os.geteuid() == 0 # type: ignore[attr-defined, no-any-return, unused-ignore] @@ -356,24 +234,8 @@ def display_startup_dialog() -> None: Ansi.LRED, ) - -def create_config_from_default() -> None: - """Create the default config from ext/config.sample.py""" - shutil.copy("ext/config.sample.py", "config.py") - - -def orjson_serialize_to_str(*args: Any, **kwargs: Any) -> str: - return orjson.dumps(*args, **kwargs).decode() - - -def get_media_type(extension: str) -> str | None: - if extension in ("jpg", "jpeg"): - return "image/jpeg" - elif extension == "png": - return "image/png" - - # return none, fastapi will attempt to figure it out - return None + if not has_internet_connectivity(): + log("No internet connectivity detected", Ansi.LYELLOW) def has_jpeg_headers_and_trailers(data_view: memoryview) -> bool: diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 8c0a92d9c..1afd3eb54 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -90,9 +90,6 @@ services: - SSL_CERT_PATH=${SSL_CERT_PATH} - SSL_KEY_PATH=${SSL_KEY_PATH} - DEVELOPER_MODE=${DEVELOPER_MODE} - # On deprecation path - - SERVER_ADDR=${SERVER_ADDR} - - SERVER_PORT=${SERVER_PORT} volumes: test-data: diff --git a/docker-compose.yml b/docker-compose.yml index 7ac66d7a7..f9210f73d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -86,9 +86,6 @@ services: - SSL_CERT_PATH=${SSL_CERT_PATH} - SSL_KEY_PATH=${SSL_KEY_PATH} - DEVELOPER_MODE=${DEVELOPER_MODE} - # On deprecation path - - SERVER_ADDR=${SERVER_ADDR} - - SERVER_PORT=${SERVER_PORT} volumes: data: diff --git a/ext/Caddyfile b/ext/Caddyfile index 3057b5659..1eb5b1872 100644 --- a/ext/Caddyfile +++ b/ext/Caddyfile @@ -8,8 +8,6 @@ c.{$DOMAIN}, ce.{$DOMAIN}, c4.{$DOMAIN}, osu.{$DOMAIN}, b.{$DOMAIN}, api.{$DOMAIN} { encode gzip - # NOTE: if you wish to use unix sockets, - # reverse_proxy * unix//tmp/bancho.sock { reverse_proxy * 127.0.0.1:{$APP_PORT} { header_up X-Real-IP {remote_host} } diff --git a/ext/nginx.conf.example b/ext/nginx.conf.example index 8a9de354f..188538239 100644 --- a/ext/nginx.conf.example +++ b/ext/nginx.conf.example @@ -4,8 +4,6 @@ upstream bancho { server 127.0.0.1:${APP_PORT}; - # NOTE: if you wish to use unix sockets, - # server unix:/tmp/bancho.sock fail_timeout=0; } server { diff --git a/main.py b/main.py index 500e59cc5..6fbaacf90 100755 --- a/main.py +++ b/main.py @@ -1,111 +1,16 @@ #!/usr/bin/env python3.11 -"""main.py - a user-friendly, safe wrapper around bancho.py's runtime - -bancho.py is an in-progress osu! server implementation for developers of all levels -of experience interested in hosting their own osu private server instance(s). - -the project is developed primarily by the Akatsuki (https://akatsuki.pw) team, -and our aim is to create the most easily maintainable, reliable, and feature-rich -osu! server implementation available. - -we're also fully open source! 
-https://github.com/osuAkatsuki/bancho.py -""" from __future__ import annotations -__author__ = "Joshua Smith (cmyui)" -__email__ = "josh@akatsuki.gg" -__discord__ = "cmyui#0425" - -import os - -# set working directory to the bancho/ directory. -os.chdir(os.path.dirname(os.path.realpath(__file__))) - -import argparse import logging -import sys -from collections.abc import Sequence import uvicorn import app.settings import app.utils -from app.logging import Ansi -from app.logging import log - - -def main(argv: Sequence[str]) -> int: - """Ensure runtime environment is ready, and start the server.""" - - parser = argparse.ArgumentParser( - description=("An open-source osu! server implementation by Akatsuki."), - ) - parser.add_argument( - "-V", - "--version", - action="version", - version=f"%(prog)s v{app.settings.VERSION}", - ) - - parser.parse_args(argv) - - """ Server should be safe to start """ - # install any debugging hooks from - # _testing/runtime.py, if present - app.utils._install_debugging_hooks() - # check our internet connection status - if not app.utils.check_connection(timeout=1.5): - log("No internet connection available.", Ansi.LYELLOW) - - # show info & any contextual warnings. +def main() -> int: app.utils.display_startup_dialog() - - # the server supports both inet and unix sockets. - - uds = None - host = None - port = None - - if ( - app.utils.is_valid_inet_address(app.settings.APP_HOST) - and app.settings.APP_PORT is not None - ): - host = app.settings.APP_HOST - port = app.settings.APP_PORT - elif ( - app.utils.is_valid_unix_address(app.settings.APP_HOST) - and app.settings.APP_PORT is None - ): - uds = app.settings.APP_HOST - - # make sure the socket file does not exist on disk and can be bound - # (uvicorn currently does not do this for us, and will raise an exc) - if os.path.exists(app.settings.APP_HOST): - try: - if ( - app.utils.processes_listening_on_unix_socket(app.settings.APP_HOST) - != 0 - ): - log( - f"There are other processes listening on {app.settings.APP_HOST}.\n" - f"If you've lost it, bancho.py can be killed gracefully with SIGINT.", - Ansi.LRED, - ) - return 1 - except Exception: - pass - else: - os.remove(app.settings.APP_HOST) - else: - raise ValueError( - "%r does not appear to be an IPv4, IPv6 or Unix address" - % app.settings.APP_HOST, - ) from None - - # run the server indefinitely uvicorn.run( "app.api.init_api:asgi_app", reload=app.settings.DEBUG, @@ -113,13 +18,11 @@ def main(argv: Sequence[str]) -> int: server_header=False, date_header=False, headers=[("bancho-version", app.settings.VERSION)], - uds=uds, - host=host or "127.0.0.1", # uvicorn defaults - port=port or 8000, # uvicorn defaults + host=app.settings.APP_HOST, + port=app.settings.APP_PORT, ) - return 0 if __name__ == "__main__": - raise SystemExit(main(sys.argv[1:])) + exit(main()) From 51299b7d553d3b7a7443c5db5718ab461a0226e8 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 17:34:57 -0500 Subject: [PATCH 13/48] maps repo delete() bugfix --- app/repositories/maps.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app/repositories/maps.py b/app/repositories/maps.py index d64a1e49f..049e8318e 100644 --- a/app/repositories/maps.py +++ b/app/repositories/maps.py @@ -392,8 +392,8 @@ async def delete(id: int) -> Map | None: params: dict[str, Any] = { "id": id, } - rec = await app.state.services.database.fetch_one(query, params) - if rec is None: + map = await app.state.services.database.fetch_one(query, params) + if map is None: return None query = """\ 
@@ -403,5 +403,5 @@ async def delete(id: int) -> Map | None: params = { "id": id, } - map = await app.state.services.database.execute(query, params) + await app.state.services.database.execute(query, params) return cast(Map, dict(map._mapping)) if map is not None else None From 529c7908cdbbac0ea131124a000bb8e006aefb08 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 18:16:53 -0500 Subject: [PATCH 14/48] add types-PyYAML to dev dependencies --- poetry.lock | 13 ++++++++++++- pyproject.toml | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index c39627f27..51c6809af 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1744,6 +1744,17 @@ files = [ {file = "types_PyMySQL-1.1.0.1-py3-none-any.whl", hash = "sha256:9aec9ee0453314d477ef26e5832b4a992bc4cc3557358d62b0fe4af760a7728f"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "types-requests" version = "2.31.0.20240125" @@ -1899,4 +1910,4 @@ cython = "*" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "64468dbf2aab41fb6a602d155a304aa26d7562256dde56ead7973e8991192b9a" +content-hash = "7b491f78b9acb025f1df3a739965f32a40d1d108ef57bb35be2b57a3380f95f5" diff --git a/pyproject.toml b/pyproject.toml index 9ab2cdcd8..6f9946b13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ types-psutil = "5.9.5.20240205" types-pymysql = "1.1.0.1" types-requests = "2.31.0.20240125" mypy = "1.8.0" +types-pyyaml = "^6.0.12.12" [build-system] requires = ["poetry-core"] From 3eb40abb4f184717841bac547ef155bf1d01342f Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 18:37:40 -0500 Subject: [PATCH 15/48] handle private match history (#616) --- app/api/domains/cho.py | 3 ++- app/objects/match.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 623a65210..05a972211 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -1392,7 +1392,8 @@ async def handle(self, player: Player) -> None: match = Match( id=match_id, name=self.match_data.name, - password=self.match_data.passwd, + password=self.match_data.passwd.removesuffix("//private"), + has_public_history=not self.match_data.passwd.endswith("//private"), map_name=self.match_data.map_name, map_id=self.match_data.map_id, map_md5=self.match_data.map_md5, diff --git a/app/objects/match.py b/app/objects/match.py index 47e41060f..750f313d6 100644 --- a/app/objects/match.py +++ b/app/objects/match.py @@ -151,6 +151,7 @@ def __init__( id: int, name: str, password: str, + has_public_history: bool, map_name: str, map_id: int, map_md5: str, @@ -166,6 +167,7 @@ def __init__( self.id = id self.name = name self.passwd = password + self.has_public_history = has_public_history self.host_id = host_id self._refs: set[Player] = set() From 44b0fddc07ebbc5094063f5389e953656374fb23 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 18:49:40 -0500 Subject: [PATCH 16/48] fix mp scrim responses --- app/objects/match.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/app/objects/match.py b/app/objects/match.py index 750f313d6..e78eefe2e 100644 
--- a/app/objects/match.py +++ b/app/objects/match.py @@ -10,16 +10,12 @@ from typing import TYPE_CHECKING from typing import TypedDict -import databases.core - import app.packets import app.settings import app.state from app.constants import regexes from app.constants.gamemodes import GameMode from app.constants.mods import Mods -from app.logging import Ansi -from app.logging import log from app.objects.beatmap import Beatmap from app.repositories.tourney_pools import TourneyPool from app.utils import escape_enum @@ -466,6 +462,8 @@ def add_suffix(score: int | float) -> str | int | float: return str(score) if ffa: + from app.objects.player import Player + assert isinstance(winner, Player) msg.append( From 92fd4201b9decd21d337f4e06a578238ac933e5a Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 21:02:43 -0500 Subject: [PATCH 17/48] remove unused code --- app/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/utils.py b/app/utils.py index f47726bee..0e9856a20 100644 --- a/app/utils.py +++ b/app/utils.py @@ -24,7 +24,6 @@ DATA_PATH = Path.cwd() / ".data" ACHIEVEMENTS_ASSETS_PATH = DATA_PATH / "assets/medals/client" DEFAULT_AVATAR_PATH = DATA_PATH / "avatars/default.jpg" -DEBUG_HOOKS_PATH = Path.cwd() / "_testing/runtime.py" def make_safe_name(name: str) -> str: From d29c4348631a72bf3ddf7ee0b9f8a89602021ac5 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Fri, 16 Feb 2024 21:20:23 -0500 Subject: [PATCH 18/48] Scrim fixes and cleanup (#617) * fix bug in mp scrim results * small refactor to unnest some code --- app/objects/match.py | 229 +++++++++++++++++++++---------------------- 1 file changed, 114 insertions(+), 115 deletions(-) diff --git a/app/objects/match.py b/app/objects/match.py index e78eefe2e..69b4b1674 100644 --- a/app/objects/match.py +++ b/app/objects/match.py @@ -372,15 +372,14 @@ async def await_submissions( while True: assert s.player is not None rc_score = s.player.recent_score - assert rc_score is not None max_age = datetime.now() - timedelta( seconds=bmap.total_length + time_waited + 0.5, ) - assert rc_score.bmap is not None if ( rc_score + and rc_score.bmap and rc_score.bmap.md5 == self.map_md5 and rc_score.server_time > max_age ): @@ -432,122 +431,122 @@ async def update_matchpoints(self, was_playing: Sequence[Slot]) -> None: for player in didnt_submit: self.chat.send_bot(f"{player} didn't submit a score (timeout: 10s).") - if scores: - ffa = self.team_type in ( - MatchTeamTypes.head_to_head, - MatchTeamTypes.tag_coop, + if not scores: + self.chat.send_bot("Scores could not be calculated.") + return None + + ffa = self.team_type in ( + MatchTeamTypes.head_to_head, + MatchTeamTypes.tag_coop, + ) + + # all scores are equal, it was a tie. + if len(scores) != 1 and len(set(scores.values())) == 1: + self.winners.append(None) + self.chat.send_bot("The point has ended in a tie!") + return None + + # Find the winner & increment their matchpoints. + winner: Player | MatchTeams = max(scores, key=lambda k: scores[k]) + self.winners.append(winner) + self.match_points[winner] += 1 + + msg: list[str] = [] + + def add_suffix(score: int | float) -> str | int | float: + if self.use_pp_scoring: + return f"{score:.2f}pp" + elif self.win_condition == MatchWinConditions.accuracy: + return f"{score:.2f}%" + elif self.win_condition == MatchWinConditions.combo: + return f"{score}x" + else: + return str(score) + + if ffa: + from app.objects.player import Player + + assert isinstance(winner, Player) + + msg.append( + f"{winner.name} takes the point! 
({add_suffix(scores[winner])} " + f"[Match avg. {add_suffix(sum(scores.values()) / len(scores))}])", ) - # all scores are equal, it was a tie. - if len(scores) != 1 and len(set(scores.values())) == 1: - self.winners.append(None) - self.chat.send_bot("The point has ended in a tie!") - return None - - # Find the winner & increment their matchpoints. - winner: Player | MatchTeams = max(scores, key=lambda k: scores[k]) - self.winners.append(winner) - self.match_points[winner] += 1 - - msg: list[str] = [] - - def add_suffix(score: int | float) -> str | int | float: - if self.use_pp_scoring: - return f"{score:.2f}pp" - elif self.win_condition == MatchWinConditions.accuracy: - return f"{score:.2f}%" - elif self.win_condition == MatchWinConditions.combo: - return f"{score}x" - else: - return str(score) - - if ffa: - from app.objects.player import Player - - assert isinstance(winner, Player) + wmp = self.match_points[winner] + + # check if match point #1 has enough points to win. + if self.winning_pts and wmp == self.winning_pts: + # we have a champion, announce & reset our match. + self.is_scrimming = False + self.reset_scrim() + self.bans.clear() + + m = f"{winner.name} takes the match! Congratulations!" + else: + # no winner, just announce the match points so far. + # for ffa, we'll only announce the top <=3 players. + m_points = sorted(self.match_points.items(), key=lambda x: x[1]) + m = f"Total Score: {' | '.join([f'{k.name} - {v}' for k, v in m_points])}" + + msg.append(m) + del m + + else: # teams + assert isinstance(winner, MatchTeams) + + r_match = regexes.TOURNEY_MATCHNAME.match(self.name) + if r_match: + match_name = r_match["name"] + team_names = { + MatchTeams.blue: r_match["T1"], + MatchTeams.red: r_match["T2"], + } + else: + match_name = self.name + team_names = {MatchTeams.blue: "Blue", MatchTeams.red: "Red"} + + # teams are binary, so we have a loser. + if winner is MatchTeams.blue: + loser = MatchTeams.red + else: + loser = MatchTeams.blue + + # from match name if available, else blue/red. + wname = team_names[winner] + lname = team_names[loser] + + # scores from the recent play + # (according to win condition) + ws = add_suffix(scores[winner]) + ls = add_suffix(scores[loser]) + + # total win/loss score in the match. + wmp = self.match_points[winner] + lmp = self.match_points[loser] + + # announce the score for the most recent play. + msg.append(f"{wname} takes the point! ({ws} vs. {ls})") + + # check if the winner has enough match points to win the match. + if self.winning_pts and wmp == self.winning_pts: + # we have a champion, announce & reset our match. + self.is_scrimming = False + self.reset_scrim() msg.append( - f"{winner.name} takes the point! ({add_suffix(scores[winner])} " - f"[Match avg. {add_suffix(sum(scores.values()) / len(scores))}])", - ) - - wmp = self.match_points[winner] - - # check if match point #1 has enough points to win. - if self.winning_pts and wmp == self.winning_pts: - # we have a champion, announce & reset our match. - self.is_scrimming = False - self.reset_scrim() - self.bans.clear() - - m = f"{winner.name} takes the match! Congratulations!" - else: - # no winner, just announce the match points so far. - # for ffa, we'll only announce the top <=3 players. 
- m_points = sorted(self.match_points.items(), key=lambda x: x[1]) - m = f"Total Score: {' | '.join([f'{k.name} - {v}' for k, v in m_points])}" - - msg.append(m) - del m - - else: # teams - assert isinstance(winner, MatchTeams) - - r_match = regexes.TOURNEY_MATCHNAME.match(self.name) - if r_match: - match_name = r_match["name"] - team_names = { - MatchTeams.blue: r_match["T1"], - MatchTeams.red: r_match["T2"], - } - else: - match_name = self.name - team_names = {MatchTeams.blue: "Blue", MatchTeams.red: "Red"} - - # teams are binary, so we have a loser. - if winner is MatchTeams.blue: - loser = MatchTeams.red - else: - loser = MatchTeams.blue - - # from match name if available, else blue/red. - wname = team_names[winner] - lname = team_names[loser] - - # scores from the recent play - # (according to win condition) - ws = add_suffix(scores[winner]) - ls = add_suffix(scores[loser]) - - # total win/loss score in the match. - wmp = self.match_points[winner] - lmp = self.match_points[loser] - - # announce the score for the most recent play. - msg.append(f"{wname} takes the point! ({ws} vs. {ls})") - - # check if the winner has enough match points to win the match. - if self.winning_pts and wmp == self.winning_pts: - # we have a champion, announce & reset our match. - self.is_scrimming = False - self.reset_scrim() - - msg.append( - f"{wname} takes the match, finishing {match_name} " - f"with a score of {wmp} - {lmp}! Congratulations!", - ) - else: - # no winner, just announce the match points so far. - msg.append(f"Total Score: {wname} | {wmp} - {lmp} | {lname}") - - if didnt_submit: - self.chat.send_bot( - "If you'd like to perform a rematch, " - "please use the `!mp rematch` command.", + f"{wname} takes the match, finishing {match_name} " + f"with a score of {wmp} - {lmp}! Congratulations!", ) + else: + # no winner, just announce the match points so far. 
+ msg.append(f"Total Score: {wname} | {wmp} - {lmp} | {lname}") + + if didnt_submit: + self.chat.send_bot( + "If you'd like to perform a rematch, " + "please use the `!mp rematch` command.", + ) - for line in msg: - self.chat.send_bot(line) - - else: - self.chat.send_bot("Scores could not be calculated.") + for line in msg: + self.chat.send_bot(line) From 215bd0fa9dffd92fc40459741126a767bf618401 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 23:06:29 -0500 Subject: [PATCH 19/48] remove remaining `__all__` consts --- app/objects/player.py | 2 -- app/objects/score.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/app/objects/player.py b/app/objects/player.py index 5a5c4ca59..c4408d522 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -48,8 +48,6 @@ from app.objects.clan import Clan from app.objects.score import Score -__all__ = ("ModeData", "Status", "Player") - @unique @pymysql_encode(escape_enum) diff --git a/app/objects/score.py b/app/objects/score.py index 4269a84c9..1051e18d0 100644 --- a/app/objects/score.py +++ b/app/objects/score.py @@ -23,8 +23,6 @@ if TYPE_CHECKING: from app.objects.player import Player -__all__ = ("Grade", "SubmissionStatus", "Score") - BEATMAPS_PATH = Path.cwd() / ".data/osu" From 6a1e1cdb2d277480af451436e147f161f728dd63 Mon Sep 17 00:00:00 2001 From: cmyui Date: Fri, 16 Feb 2024 23:08:31 -0500 Subject: [PATCH 20/48] remove unused variables --- app/api/domains/cho.py | 1 - app/api/domains/osu.py | 1 - app/api/init_api.py | 1 - app/repositories/favourites.py | 3 +-- tests/integration/domains/osu_test.py | 2 +- 5 files changed, 2 insertions(+), 6 deletions(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 05a972211..5a49de879 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -4,7 +4,6 @@ import asyncio import re -import string import struct import time from collections.abc import Callable diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index b533e956b..068d4320a 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -6,7 +6,6 @@ import hashlib import random import secrets -from base64 import b64decode from collections import defaultdict from collections.abc import Awaitable from collections.abc import Callable diff --git a/app/api/init_api.py b/app/api/init_api.py index 5f856364c..91b7cd755 100644 --- a/app/api/init_api.py +++ b/app/api/init_api.py @@ -3,7 +3,6 @@ import asyncio import io -import os import pprint import sys from collections.abc import AsyncIterator diff --git a/app/repositories/favourites.py b/app/repositories/favourites.py index 1f01fd8dd..7cdb52053 100644 --- a/app/repositories/favourites.py +++ b/app/repositories/favourites.py @@ -1,7 +1,6 @@ from __future__ import annotations import textwrap -from datetime import datetime from typing import Any from typing import TypedDict from typing import cast @@ -42,7 +41,7 @@ async def create( "userid": userid, "setid": setid, } - rec_id = await app.state.services.database.execute(query, params) + await app.state.services.database.execute(query, params) query = f"""\ SELECT {READ_PARAMS} diff --git a/tests/integration/domains/osu_test.py b/tests/integration/domains/osu_test.py index aa6a225e8..49adc7efa 100644 --- a/tests/integration/domains/osu_test.py +++ b/tests/integration/domains/osu_test.py @@ -104,7 +104,7 @@ async def test_score_submission( # cho token must be valid uuid try: - session_token = UUID(response.headers["cho-token"]) + UUID(response.headers["cho-token"]) except ValueError: raise 
AssertionError( "cho-token is not a valid uuid", From 86af41a35095396a446e531c7a9ea60d92b060df Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Sat, 17 Feb 2024 17:29:49 -0500 Subject: [PATCH 21/48] Rename players_repo to users_repo (#621) --- app/api/domains/cho.py | 8 ++++---- app/api/domains/osu.py | 8 ++++---- app/api/v1/api.py | 12 ++++++------ app/api/v2/players.py | 8 ++++---- app/commands.py | 14 +++++++------- app/objects/clan.py | 6 +++--- app/objects/collections.py | 8 ++++---- app/objects/player.py | 14 +++++++------- app/repositories/{players.py => users.py} | 0 9 files changed, 39 insertions(+), 39 deletions(-) rename app/repositories/{players.py => users.py} (100%) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 5a49de879..cfe460055 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -63,7 +63,7 @@ from app.repositories import client_hashes as client_hashes_repo from app.repositories import ingame_logins as logins_repo from app.repositories import mail as mail_repo -from app.repositories import players as players_repo +from app.repositories import users as users_repo from app.state import services from app.usecases.performance import ScoreParams @@ -599,8 +599,8 @@ def parse_adapters_string(adapters_string: str) -> tuple[list[str], bool]: async def authenticate( username: str, untrusted_password: bytes, -) -> players_repo.Player | None: - user_info = await players_repo.fetch_one( +) -> users_repo.Player | None: + user_info = await users_repo.fetch_one( name=username, fetch_all_fields=True, ) @@ -817,7 +817,7 @@ async def handle_osu_login_request( # country wasn't stored on registration. log(f"Fixing {login_data['username']}'s country.", Ansi.LGREEN) - await players_repo.update( + await users_repo.update( id=user_info["id"], country=geoloc["country"]["acronym"], ) diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index 068d4320a..517806bf9 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -64,10 +64,10 @@ from app.repositories import favourites as favourites_repo from app.repositories import mail as mail_repo from app.repositories import maps as maps_repo -from app.repositories import players as players_repo from app.repositories import ratings as ratings_repo from app.repositories import scores as scores_repo from app.repositories import stats as stats_repo +from app.repositories import users as users_repo from app.repositories.achievements import Achievement from app.usecases import achievements as achievements_usecases from app.usecases import user_achievements as user_achievements_usecases @@ -1705,7 +1705,7 @@ async def register_account( errors["username"].append("Disallowed username; pick another.") if "username" not in errors: - if await players_repo.fetch_one(name=username): + if await users_repo.fetch_one(name=username): errors["username"].append("Username already taken by another player.") # Emails must: @@ -1714,7 +1714,7 @@ async def register_account( if not regexes.EMAIL.match(email): errors["user_email"].append("Invalid email syntax.") else: - if await players_repo.fetch_one(email=email): + if await users_repo.fetch_one(email=email): errors["user_email"].append("Email already taken by another player.") # Passwords must: @@ -1754,7 +1754,7 @@ async def register_account( async with app.state.services.database.transaction(): # add to `users` table. 
- player = await players_repo.create( + player = await users_repo.create( name=username, email=email, pw_bcrypt=pw_bcrypt, diff --git a/app/api/v1/api.py b/app/api/v1/api.py index 20469a57c..4a2aff77e 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -26,11 +26,11 @@ from app.objects.beatmap import ensure_osu_file_is_available from app.objects.clan import Clan from app.objects.player import Player -from app.repositories import players as players_repo from app.repositories import scores as scores_repo from app.repositories import stats as stats_repo from app.repositories import tourney_pool_maps as tourney_pool_maps_repo from app.repositories import tourney_pools as tourney_pools_repo +from app.repositories import users as users_repo from app.usecases.performance import ScoreParams AVATARS_PATH = SystemPath.cwd() / ".data/avatars" @@ -228,7 +228,7 @@ async def api_get_player_count() -> Response: "counts": { # -1 for the bot, who is always online "online": len(app.state.sessions.players.unrestricted) - 1, - "total": await players_repo.fetch_count(), + "total": await users_repo.fetch_count(), }, }, ) @@ -249,9 +249,9 @@ async def api_get_player_info( # get user info from username or user id if username: - user_info = await players_repo.fetch_one(name=username) + user_info = await users_repo.fetch_one(name=username) else: # if user_id - user_info = await players_repo.fetch_one(id=user_id) + user_info = await users_repo.fetch_one(id=user_id) if user_info is None: return ORJSONResponse( @@ -339,9 +339,9 @@ async def api_get_player_status( # no such player online, return their last seen time if they exist in sql if username: - row = await players_repo.fetch_one(name=username) + row = await users_repo.fetch_one(name=username) else: # if userid - row = await players_repo.fetch_one(id=user_id) + row = await users_repo.fetch_one(id=user_id) if not row: return ORJSONResponse( diff --git a/app/api/v2/players.py b/app/api/v2/players.py index f267f714d..706077f3c 100644 --- a/app/api/v2/players.py +++ b/app/api/v2/players.py @@ -13,8 +13,8 @@ from app.api.v2.models.players import Player from app.api.v2.models.players import PlayerStats from app.api.v2.models.players import PlayerStatus -from app.repositories import players as players_repo from app.repositories import stats as stats_repo +from app.repositories import users as users_repo router = APIRouter() @@ -30,7 +30,7 @@ async def get_players( page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), ) -> Success[list[Player]] | Failure: - players = await players_repo.fetch_many( + players = await users_repo.fetch_many( priv=priv, country=country, clan_id=clan_id, @@ -40,7 +40,7 @@ async def get_players( page=page, page_size=page_size, ) - total_players = await players_repo.fetch_count( + total_players = await users_repo.fetch_count( priv=priv, country=country, clan_id=clan_id, @@ -63,7 +63,7 @@ async def get_players( @router.get("/players/{player_id}") async def get_player(player_id: int) -> Success[Player] | Failure: - data = await players_repo.fetch_one(id=player_id) + data = await users_repo.fetch_one(id=player_id) if data is None: return responses.failure( message="Player not found.", diff --git a/app/commands.py b/app/commands.py index 262ae3765..963297065 100644 --- a/app/commands.py +++ b/app/commands.py @@ -61,9 +61,9 @@ from app.repositories import logs as logs_repo from app.repositories import map_requests as map_requests_repo from app.repositories import maps as maps_repo -from app.repositories import players as 
players_repo from app.repositories import tourney_pool_maps as tourney_pool_maps_repo from app.repositories import tourney_pools as tourney_pools_repo +from app.repositories import users as users_repo from app.usecases.performance import ScoreParams if TYPE_CHECKING: @@ -275,11 +275,11 @@ async def changename(ctx: Context) -> str | None: if name in app.settings.DISALLOWED_NAMES: return "Disallowed username; pick another." - if await players_repo.fetch_one(name=name): + if await users_repo.fetch_one(name=name): return "Username already taken by another player." # all checks passed, update their name - await players_repo.update(ctx.player.id, name=name) + await users_repo.update(ctx.player.id, name=name) ctx.player.enqueue( app.packets.notification(f"Your username has been changed to {name}!"), @@ -565,7 +565,7 @@ async def apikey(ctx: Context) -> str | None: # generate new token ctx.player.api_key = str(uuid.uuid4()) - await players_repo.update(ctx.player.id, api_key=ctx.player.api_key) + await users_repo.update(ctx.player.id, api_key=ctx.player.api_key) app.state.sessions.api_keys[ctx.player.api_key] = ctx.player.id return f"API key generated. Copy your api key from (this url)[http://{ctx.player.api_key}]." @@ -2330,7 +2330,7 @@ async def clan_create(ctx: Context) -> str | None: clan.owner_id = ctx.player.id clan.member_ids.add(ctx.player.id) - await players_repo.update( + await users_repo.update( ctx.player.id, clan_id=clan.id, clan_priv=ClanPrivileges.Owner, @@ -2369,7 +2369,7 @@ async def clan_disband(ctx: Context) -> str | None: # reset their clan privs (cache & sql). # NOTE: only online players need be to be uncached. for member_id in clan.member_ids: - await players_repo.update(member_id, clan_id=0, clan_priv=0) + await users_repo.update(member_id, clan_id=0, clan_priv=0) member = app.state.sessions.players.get(id=member_id) if member: @@ -2398,7 +2398,7 @@ async def clan_info(ctx: Context) -> str | None: msg = [f"{clan!r} | Founded {clan.created_at:%b %d, %Y}."] # get members privs from sql - clan_members = await players_repo.fetch_many(clan_id=clan.id) + clan_members = await users_repo.fetch_many(clan_id=clan.id) for member in sorted(clan_members, key=lambda m: m["clan_priv"], reverse=True): priv_str = ("Member", "Officer", "Owner")[member["clan_priv"] - 1] msg.append(f"[{priv_str}] {member['name']}") diff --git a/app/objects/clan.py b/app/objects/clan.py index 9ba79f539..421197816 100644 --- a/app/objects/clan.py +++ b/app/objects/clan.py @@ -6,7 +6,7 @@ import app.state from app.constants.privileges import ClanPrivileges from app.repositories import clans as clans_repo -from app.repositories import players as players_repo +from app.repositories import users as users_repo if TYPE_CHECKING: from app.objects.player import Player @@ -53,7 +53,7 @@ async def remove_member(self, player: Player) -> None: """Remove a given player from the clan's members.""" self.member_ids.remove(player.id) - await players_repo.update(player.id, clan_id=0, clan_priv=0) + await users_repo.update(player.id, clan_id=0, clan_priv=0) if not self.member_ids: # no members left, disband clan. 
@@ -66,7 +66,7 @@ async def remove_member(self, player: Player) -> None: self.owner_id = next(iter(self.member_ids)) await clans_repo.update(self.id, owner=self.owner_id) - await players_repo.update(self.owner_id, clan_priv=ClanPrivileges.Owner) + await users_repo.update(self.owner_id, clan_priv=ClanPrivileges.Owner) player.clan = None player.clan_priv = None diff --git a/app/objects/collections.py b/app/objects/collections.py index 17ec7d4ea..55513db82 100644 --- a/app/objects/collections.py +++ b/app/objects/collections.py @@ -22,7 +22,7 @@ from app.objects.player import Player from app.repositories import channels as channels_repo from app.repositories import clans as clans_repo -from app.repositories import players as players_repo +from app.repositories import users as users_repo from app.utils import make_safe_name # TODO: decorator for these collections which automatically @@ -198,7 +198,7 @@ async def get_sql( ) -> Player | None: """Get a player by token, id, or name from sql.""" # try to get from sql. - player = await players_repo.fetch_one( + player = await users_repo.fetch_one( id=id, name=name, fetch_all_fields=True, @@ -349,7 +349,7 @@ async def prepare(self, db_conn: databases.core.Connection) -> None: """Fetch data from sql & return; preparing to run the server.""" log("Fetching clans from sql.", Ansi.LCYAN) for row in await clans_repo.fetch_many(): - clan_members = await players_repo.fetch_many(clan_id=row["id"]) + clan_members = await users_repo.fetch_many(clan_id=row["id"]) clan = Clan( id=row["id"], name=row["name"], @@ -367,7 +367,7 @@ async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: await app.state.sessions.channels.prepare(db_conn) await app.state.sessions.clans.prepare(db_conn) - bot = await players_repo.fetch_one(id=1) + bot = await users_repo.fetch_one(id=1) if bot is None: raise RuntimeError("Bot account not found in database.") diff --git a/app/objects/player.py b/app/objects/player.py index c4408d522..356e8e1d7 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -35,8 +35,8 @@ from app.objects.score import Grade from app.objects.score import Score from app.repositories import logs as logs_repo -from app.repositories import players as players_repo from app.repositories import stats as stats_repo +from app.repositories import users as users_repo from app.state.services import Geolocation from app.utils import escape_enum from app.utils import make_safe_name @@ -420,7 +420,7 @@ async def update_privs(self, new: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await players_repo.update( + await users_repo.update( id=self.id, priv=self.priv, ) @@ -432,7 +432,7 @@ async def add_privs(self, bits: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await players_repo.update( + await users_repo.update( id=self.id, priv=self.priv, ) @@ -449,7 +449,7 @@ async def remove_privs(self, bits: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await players_repo.update( + await users_repo.update( id=self.id, priv=self.priv, ) @@ -536,7 +536,7 @@ async def silence(self, admin: Player, duration: float, reason: str) -> None: """Silence `self` for `duration` seconds, and log to sql.""" self.silence_end = int(time.time() + duration) - await players_repo.update( + await users_repo.update( id=self.id, silence_end=self.silence_end, ) @@ -564,7 +564,7 @@ async def unsilence(self, admin: Player, reason: 
str) -> None: """Unsilence `self`, and log to sql.""" self.silence_end = int(time.time()) - await players_repo.update( + await users_repo.update( id=self.id, silence_end=self.silence_end, ) @@ -1000,7 +1000,7 @@ async def stats_from_sql_full(self, db_conn: databases.core.Connection) -> None: def update_latest_activity_soon(self) -> None: """Update the player's latest activity in the database.""" - task = players_repo.update( + task = users_repo.update( id=self.id, latest_activity=int(time.time()), ) diff --git a/app/repositories/players.py b/app/repositories/users.py similarity index 100% rename from app/repositories/players.py rename to app/repositories/users.py From 6ca6951b0f77c78f7af76b0fa6fd5277d8ab93d4 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Sat, 17 Feb 2024 17:37:06 -0500 Subject: [PATCH 22/48] Deprecate clans collection, only hold clan id references & don't cache in memory (#619) * deprecate clans collection -- move everything to use references & don't hold state in memory * delete clans collection * delete Clan object and fix bugs * remove unnecessary type-ignore * remove todos which will never be done * bug fixes from testing * prefer none explicitly * fix bug from conflict * fix names * fix --- app/api/domains/cho.py | 7 +- app/api/domains/osu.py | 9 ++- app/api/v1/api.py | 133 ++++++++++++++++++------------------- app/commands.py | 118 +++++++++++++++++++------------- app/objects/__init__.py | 1 - app/objects/clan.py | 75 --------------------- app/objects/collections.py | 84 +---------------------- app/objects/player.py | 32 +-------- app/repositories/clans.py | 8 +-- app/state/sessions.py | 2 - app/utils.py | 8 +++ 11 files changed, 164 insertions(+), 313 deletions(-) delete mode 100644 app/objects/clan.py diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index cfe460055..0d153c9cf 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -42,7 +42,6 @@ from app.objects.beatmap import Beatmap from app.objects.beatmap import ensure_osu_file_is_available from app.objects.channel import Channel -from app.objects.clan import Clan from app.objects.match import MAX_MATCH_NAME_LENGTH from app.objects.match import Match from app.objects.match import MatchTeams @@ -791,10 +790,10 @@ async def handle_osu_login_request( """ All checks passed, player is safe to login """ # get clan & clan priv if we're in a clan - clan: Clan | None = None + clan_id: int | None = None clan_priv: ClanPrivileges | None = None if user_info["clan_id"] != 0: - clan = app.state.sessions.clans.get(id=user_info["clan_id"]) + clan_id = user_info["clan_id"] clan_priv = ClanPrivileges(user_info["clan_priv"]) db_country = user_info["country"] @@ -838,7 +837,7 @@ async def handle_osu_login_request( priv=Privileges(user_info["priv"]), pw_bcrypt=user_info["pw_bcrypt"].encode(), token=Player.generate_token(), - clan=clan, + clan_id=clan_id, clan_priv=clan_priv, geoloc=geoloc, utc_offset=login_data["utc_offset"], diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index 517806bf9..2795f8d8a 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -60,6 +60,7 @@ from app.objects.score import Grade from app.objects.score import Score from app.objects.score import SubmissionStatus +from app.repositories import clans as clans_repo from app.repositories import comments as comments_repo from app.repositories import favourites as favourites_repo from app.repositories import mail as mail_repo @@ -1387,10 +1388,16 @@ async def getScores( return 
Response("\n".join(response_lines).encode()) if personal_best_score_row is not None: + user_clan = await clans_repo.fetch_one(id=player.clan_id) + display_name = ( + f"[{user_clan['tag']}] {player.name}" + if user_clan is not None + else player.name + ) response_lines.append( SCORE_LISTING_FMTSTR.format( **personal_best_score_row, - name=player.full_name, + name=display_name, userid=player.id, score=int(round(personal_best_score_row["_score"])), has_replay="1", diff --git a/app/api/v1/api.py b/app/api/v1/api.py index 4a2aff77e..6789d32c4 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -24,8 +24,7 @@ from app.constants.mods import Mods from app.objects.beatmap import Beatmap from app.objects.beatmap import ensure_osu_file_is_available -from app.objects.clan import Clan -from app.objects.player import Player +from app.repositories import clans as clans_repo from app.repositories import scores as scores_repo from app.repositories import stats as stats_repo from app.repositories import tourney_pool_maps as tourney_pool_maps_repo @@ -70,50 +69,6 @@ DATETIME_OFFSET = 0x89F7FF5F7B58000 -def format_clan_basic(clan: Clan) -> dict[str, object]: - return { - "id": clan.id, - "name": clan.name, - "tag": clan.tag, - "members": len(clan.member_ids), - } - - -def format_player_basic(player: Player) -> dict[str, object]: - return { - "id": player.id, - "name": player.name, - "country": player.geoloc["country"]["acronym"], - "clan": format_clan_basic(player.clan) if player.clan else None, - "online": player.is_online, - } - - -def format_map_basic(m: Beatmap) -> dict[str, object]: - return { - "id": m.id, - "md5": m.md5, - "set_id": m.set_id, - "artist": m.artist, - "title": m.title, - "version": m.version, - "creator": m.creator, - "last_update": m.last_update, - "total_length": m.total_length, - "max_combo": m.max_combo, - "status": m.status, - "plays": m.plays, - "passes": m.passes, - "mode": m.mode, - "bpm": m.bpm, - "cs": m.cs, - "od": m.od, - "ar": m.ar, - "hp": m.hp, - "diff": m.diff, - } - - @router.get("/calculate_pp") async def api_calculate_pp( token: HTTPCredentials = Depends(oauth2_scheme), @@ -498,16 +453,20 @@ async def api_get_player_scores( bmap = await Beatmap.from_md5(row.pop("map_md5")) row["beatmap"] = bmap.as_dict if bmap else None + clan: clans_repo.Clan | None = None + if player.clan_id: + clan = await clans_repo.fetch_one(id=player.clan_id) + player_info = { "id": player.id, "name": player.name, "clan": ( { - "id": player.clan.id, - "name": player.clan.name, - "tag": player.clan.tag, + "id": clan["id"], + "name": clan["name"], + "tag": clan["tag"], } - if player.clan + if clan is not None else None ), } @@ -949,36 +908,31 @@ async def api_get_clan( clan_id: int = Query(..., alias="id", ge=1, le=2_147_483_647), ) -> Response: """Return information of a given clan.""" - clan = app.state.sessions.clans.get(id=clan_id) + clan = await clans_repo.fetch_one(id=clan_id) if not clan: return ORJSONResponse( {"status": "Clan not found."}, status_code=status.HTTP_404_NOT_FOUND, ) - members: list[Player] = [] + clan_members = await users_repo.fetch_many(clan_id=clan["id"]) - for member_id in clan.member_ids: - member = await app.state.sessions.players.from_cache_or_sql(id=member_id) - assert member is not None - members.append(member) - - owner = await app.state.sessions.players.from_cache_or_sql(id=clan.owner_id) + owner = await app.state.sessions.players.from_cache_or_sql(id=clan["owner"]) assert owner is not None return ORJSONResponse( { - "id": clan.id, - "name": clan.name, - 
"tag": clan.tag, + "id": clan["id"], + "name": clan["name"], + "tag": clan["tag"], "members": [ { - "id": member.id, - "name": member.name, - "country": member.geoloc["country"]["acronym"], - "rank": ("Member", "Officer", "Owner")[member.clan_priv - 1], # type: ignore + "id": member["id"], + "name": member["name"], + "country": member["country"], + "rank": ("Member", "Officer", "Owner")[member["clan_priv"] - 1], } - for member in members + for member in clan_members ], "owner": { "id": owner.id, @@ -1017,14 +971,57 @@ async def api_get_pool( status_code=status.HTTP_404_NOT_FOUND, ) + pool_creator_clan = await clans_repo.fetch_one(id=pool_creator.clan_id) + pool_creator_clan_members: list[users_repo.Player] = [] + if pool_creator_clan is not None: + pool_creator_clan_members = await users_repo.fetch_many( + clan_id=pool_creator.clan_id, + ) + return ORJSONResponse( { "id": tourney_pool["id"], "name": tourney_pool["name"], "created_at": tourney_pool["created_at"], - "created_by": format_player_basic(pool_creator), + "created_by": { + "id": pool_creator.id, + "name": pool_creator.name, + "country": pool_creator.geoloc["country"]["acronym"], + "clan": ( + { + "id": pool_creator_clan["id"], + "name": pool_creator_clan["name"], + "tag": pool_creator_clan["tag"], + "members": len(pool_creator_clan_members), + } + if pool_creator_clan is not None + else None + ), + "online": pool_creator.is_online, + }, "maps": { - f"{mods!r}{slot}": format_map_basic(bmap) + f"{mods!r}{slot}": { + "id": bmap.id, + "md5": bmap.md5, + "set_id": bmap.set_id, + "artist": bmap.artist, + "title": bmap.title, + "version": bmap.version, + "creator": bmap.creator, + "last_update": bmap.last_update, + "total_length": bmap.total_length, + "max_combo": bmap.max_combo, + "status": bmap.status, + "plays": bmap.plays, + "passes": bmap.passes, + "mode": bmap.mode, + "bpm": bmap.bpm, + "cs": bmap.cs, + "od": bmap.od, + "ar": bmap.ar, + "hp": bmap.hp, + "diff": bmap.diff, + } for (mods, slot), bmap in tourney_pool_maps.items() }, }, diff --git a/app/commands.py b/app/commands.py index 963297065..e0fca667b 100644 --- a/app/commands.py +++ b/app/commands.py @@ -49,7 +49,6 @@ from app.objects.beatmap import Beatmap from app.objects.beatmap import RankedStatus from app.objects.beatmap import ensure_osu_file_is_available -from app.objects.clan import Clan from app.objects.match import Match from app.objects.match import MatchTeams from app.objects.match import MatchTeamTypes @@ -868,9 +867,14 @@ async def user(ctx: Context) -> str | None: else "False" ) + user_clan = await clans_repo.fetch_one(id=player.clan_id) + display_name = ( + f"[{user_clan['tag']}] {player.name}" if user_clan is not None else player.name + ) + return "\n".join( ( - f'[{"Bot" if player.is_bot_client else "Player"}] {player.full_name} ({player.id})', + f'[{"Bot" if player.is_bot_client else "Player"}] {display_name} ({player.id})', f"Privileges: {priv_list}", f"Donator: {donator_info}", f"Channels: {[c._name for c in player.channels]}", @@ -2295,54 +2299,43 @@ async def clan_create(ctx: Context) -> str | None: if not 2 <= len(name) <= 16: return "Clan name may be 2-16 characters long." - if ctx.player.clan: - return f"You're already a member of {ctx.player.clan}!" + if ctx.player.clan_id: + clan = await clans_repo.fetch_one(id=ctx.player.clan_id) + if clan: + clan_display_name = f"[{clan['tag']}] {clan['name']}" + return f"You're already a member of {clan_display_name}!" 
- if app.state.sessions.clans.get(name=name): + if await clans_repo.fetch_one(name=name): return "That name has already been claimed by another clan." - if app.state.sessions.clans.get(tag=tag): + if await clans_repo.fetch_one(tag=tag): return "That tag has already been claimed by another clan." - created_at = datetime.now() - # add clan to sql - persisted_clan = await clans_repo.create( + new_clan = await clans_repo.create( name=name, tag=tag, owner=ctx.player.id, ) - # add clan to cache - clan = Clan( - id=persisted_clan["id"], - name=name, - tag=tag, - created_at=created_at, - owner_id=ctx.player.id, - ) - app.state.sessions.clans.append(clan) - # set owner's clan & clan priv (cache & sql) - ctx.player.clan = clan + ctx.player.clan_id = new_clan["id"] ctx.player.clan_priv = ClanPrivileges.Owner - clan.owner_id = ctx.player.id - clan.member_ids.add(ctx.player.id) - await users_repo.update( ctx.player.id, - clan_id=clan.id, + clan_id=new_clan["id"], clan_priv=ClanPrivileges.Owner, ) # announce clan creation announce_chan = app.state.sessions.channels.get_by_name("#announce") + clan_display_name = f"[{new_clan['tag']}] {new_clan['name']}" if announce_chan: - msg = f"\x01ACTION founded {clan!r}." + msg = f"\x01ACTION founded {clan_display_name}." announce_chan.send(msg, sender=ctx.player, to_self=True) - return f"{clan!r} created." + return f"{clan_display_name} founded." @clan_commands.add(Privileges.UNRESTRICTED, aliases=["delete", "d"]) @@ -2353,36 +2346,41 @@ async def clan_disband(ctx: Context) -> str | None: if ctx.player not in app.state.sessions.players.staff: return "Only staff members may disband the clans of others." - clan = app.state.sessions.clans.get(tag=" ".join(ctx.args).upper()) + clan = await clans_repo.fetch_one(tag=" ".join(ctx.args).upper()) if not clan: return "Could not find a clan by that tag." else: + if ctx.player.clan_id is None: + return "You're not a member of a clan!" + # disband the player's clan - clan = ctx.player.clan + clan = await clans_repo.fetch_one(id=ctx.player.clan_id) if not clan: return "You're not a member of a clan!" - await clans_repo.delete(clan.id) - app.state.sessions.clans.remove(clan) + await clans_repo.delete(clan["id"]) - # remove all members from the clan, - # reset their clan privs (cache & sql). - # NOTE: only online players need be to be uncached. - for member_id in clan.member_ids: + # remove all members from the clan + clan_member_ids = [ + clan_member["id"] + for clan_member in await users_repo.fetch_many(clan_id=clan["id"]) + ] + for member_id in clan_member_ids: await users_repo.update(member_id, clan_id=0, clan_priv=0) member = app.state.sessions.players.get(id=member_id) if member: - member.clan = None + member.clan_id = None member.clan_priv = None # announce clan disbanding announce_chan = app.state.sessions.channels.get_by_name("#announce") + clan_display_name = f"[{clan['tag']}] {clan['name']}" if announce_chan: - msg = f"\x01ACTION disbanded {clan!r}." + msg = f"\x01ACTION disbanded {clan_display_name}." announce_chan.send(msg, sender=ctx.player, to_self=True) - return f"{clan!r} disbanded." + return f"{clan_display_name} disbanded." @clan_commands.add(Privileges.UNRESTRICTED, aliases=["i"]) @@ -2391,14 +2389,15 @@ async def clan_info(ctx: Context) -> str | None: if not ctx.args: return "Invalid syntax: !clan info " - clan = app.state.sessions.clans.get(tag=" ".join(ctx.args).upper()) + clan = await clans_repo.fetch_one(tag=" ".join(ctx.args).upper()) if not clan: return "Could not find a clan by that tag." 
- msg = [f"{clan!r} | Founded {clan.created_at:%b %d, %Y}."] + clan_display_name = f"[{clan['tag']}] {clan['name']}" + msg = [f"{clan_display_name} | Founded {clan['created_at']:%b %d, %Y}."] # get members privs from sql - clan_members = await users_repo.fetch_many(clan_id=clan.id) + clan_members = await users_repo.fetch_many(clan_id=clan["id"]) for member in sorted(clan_members, key=lambda m: m["clan_priv"], reverse=True): priv_str = ("Member", "Officer", "Owner")[member["clan_priv"] - 1] msg.append(f"[{priv_str}] {member['name']}") @@ -2409,13 +2408,34 @@ async def clan_info(ctx: Context) -> str | None: @clan_commands.add(Privileges.UNRESTRICTED) async def clan_leave(ctx: Context) -> str | None: """Leaves the clan you're in.""" - if not ctx.player.clan: + if not ctx.player.clan_id: return "You're not in a clan." elif ctx.player.clan_priv == ClanPrivileges.Owner: return "You must transfer your clan's ownership before leaving it. Alternatively, you can use !clan disband." - await ctx.player.clan.remove_member(ctx.player) - return f"You have successfully left {ctx.player.clan!r}." + clan = await clans_repo.fetch_one(id=ctx.player.clan_id) + if not clan: + return "You're not in a clan." + + clan_members = await users_repo.fetch_many(clan_id=clan["id"]) + + await users_repo.update(ctx.player.id, clan_id=0, clan_priv=0) + ctx.player.clan_id = None + ctx.player.clan_priv = None + + clan_display_name = f"[{clan['tag']}] {clan['name']}" + + if not clan_members: + # no members left, disband clan + await clans_repo.delete(clan["id"]) + + # announce clan disbanding + announce_chan = app.state.sessions.channels.get_by_name("#announce") + if announce_chan: + msg = f"\x01ACTION disbanded {clan_display_name}." + announce_chan.send(msg, sender=ctx.player, to_self=True) + + return f"You have successfully left {clan_display_name}." # TODO: !clan inv, !clan join, !clan leave @@ -2432,14 +2452,16 @@ async def clan_list(ctx: Context) -> str | None: else: offset = 0 - total_clans = len(app.state.sessions.clans) - if offset >= total_clans: + all_clans = await clans_repo.fetch_many(page=None, page_size=None) + num_clans = len(all_clans) + if offset >= num_clans: return "No clans found." - msg = [f"bancho.py clans listing ({total_clans} total)."] + msg = [f"bancho.py clans listing ({num_clans} total)."] - for idx, clan in enumerate(app.state.sessions.clans, offset): - msg.append(f"{idx + 1}. {clan!r}") + for idx, clan in enumerate(all_clans, offset): + clan_display_name = f"[{clan['tag']}] {clan['name']}" + msg.append(f"{idx + 1}. {clan_display_name}") return "\n".join(msg) diff --git a/app/objects/__init__.py b/app/objects/__init__.py index f0465f975..761dd983b 100644 --- a/app/objects/__init__.py +++ b/app/objects/__init__.py @@ -4,7 +4,6 @@ from . import achievement from . import beatmap from . import channel -from . import clan from . import collections from . import match from . 
import models diff --git a/app/objects/clan.py b/app/objects/clan.py deleted file mode 100644 index 421197816..000000000 --- a/app/objects/clan.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import TYPE_CHECKING - -import app.state -from app.constants.privileges import ClanPrivileges -from app.repositories import clans as clans_repo -from app.repositories import users as users_repo - -if TYPE_CHECKING: - from app.objects.player import Player - - -class Clan: - """A class to represent a single bancho.py clan.""" - - def __init__( - self, - id: int, - name: str, - tag: str, - created_at: datetime, - owner_id: int, - member_ids: set[int] | None = None, - ) -> None: - """A class representing one of bancho.py's clans.""" - self.id = id - self.name = name - self.tag = tag - self.created_at = created_at - - self.owner_id = owner_id # userid - - if member_ids is None: - member_ids = set() - - self.member_ids = member_ids # userids - - async def add_member(self, player: Player) -> None: - """Add a given player to the clan's members.""" - self.member_ids.add(player.id) - - await app.state.services.database.execute( - "UPDATE users SET clan_id = :clan_id, clan_priv = 1 WHERE id = :user_id", - {"clan_id": self.id, "user_id": player.id}, - ) - - player.clan = self - player.clan_priv = ClanPrivileges.Member - - async def remove_member(self, player: Player) -> None: - """Remove a given player from the clan's members.""" - self.member_ids.remove(player.id) - - await users_repo.update(player.id, clan_id=0, clan_priv=0) - - if not self.member_ids: - # no members left, disband clan. - await clans_repo.delete(self.id) - app.state.sessions.clans.remove(self) - elif player.id == self.owner_id: - # owner leaving and members left, - # transfer the ownership. - # TODO: prefer officers - self.owner_id = next(iter(self.member_ids)) - - await clans_repo.update(self.id, owner=self.owner_id) - await users_repo.update(self.owner_id, clan_priv=ClanPrivileges.Owner) - - player.clan = None - player.clan_priv = None - - def __repr__(self) -> str: - return f"[{self.tag}] {self.name}" diff --git a/app/objects/collections.py b/app/objects/collections.py index 55513db82..67af86dcd 100644 --- a/app/objects/collections.py +++ b/app/objects/collections.py @@ -1,5 +1,3 @@ -# TODO: there is still a lot of inconsistency -# in a lot of these classes; needs refactor. from __future__ import annotations from collections.abc import Iterable @@ -17,7 +15,6 @@ from app.logging import Ansi from app.logging import log from app.objects.channel import Channel -from app.objects.clan import Clan from app.objects.match import Match from app.objects.player import Player from app.repositories import channels as channels_repo @@ -25,9 +22,6 @@ from app.repositories import users as users_repo from app.utils import make_safe_name -# TODO: decorator for these collections which automatically -# adds debugging to their append/remove/insert/extend methods. 
- class Channels(list[Channel]): """The currently active chat channels on the server.""" @@ -206,10 +200,10 @@ async def get_sql( if player is None: return None - clan: Clan | None = None + clan_id: int | None = None clan_priv: ClanPrivileges | None = None if player["clan_id"] != 0: - clan = app.state.sessions.clans.get(id=player["clan_id"]) + clan_id = player["clan_id"] clan_priv = ClanPrivileges(player["clan_priv"]) return Player( @@ -218,7 +212,7 @@ async def get_sql( priv=Privileges(player["priv"]), pw_bcrypt=player["pw_bcrypt"].encode(), token=Player.generate_token(), - clan=clan, + clan_id=clan_id, clan_priv=clan_priv, geoloc={ "latitude": 0.0, @@ -290,82 +284,10 @@ def remove(self, player: Player) -> None: super().remove(player) -class Clans(list[Clan]): - """The currently active clans on the server.""" - - def __iter__(self) -> Iterator[Clan]: - return super().__iter__() - - def __contains__(self, o: object) -> bool: - """Check whether internal list contains `o`.""" - # Allow string to be passed to compare vs. name. - if isinstance(o, str): - return o in (clan.name for clan in self) - else: - return o in self - - def get( - self, - id: int | None = None, - name: str | None = None, - tag: str | None = None, - ) -> Clan | None: - """Get a clan by name, tag, or id.""" - for clan in self: - if id is not None: - if clan.id == id: - return clan - elif name is not None: - if clan.name == name: - return clan - elif tag is not None: - if clan.tag == tag: - return clan - - return None - - def append(self, clan: Clan) -> None: - """Append `clan` to the list.""" - super().append(clan) - - if app.settings.DEBUG: - log(f"{clan} added to clans list.") - - def extend(self, clans: Iterable[Clan]) -> None: - """Extend the list with `clans`.""" - super().extend(clans) - - if app.settings.DEBUG: - log(f"{clans} added to clans list.") - - def remove(self, clan: Clan) -> None: - """Remove `clan` from the list.""" - super().remove(clan) - - if app.settings.DEBUG: - log(f"{clan} removed from clans list.") - - async def prepare(self, db_conn: databases.core.Connection) -> None: - """Fetch data from sql & return; preparing to run the server.""" - log("Fetching clans from sql.", Ansi.LCYAN) - for row in await clans_repo.fetch_many(): - clan_members = await users_repo.fetch_many(clan_id=row["id"]) - clan = Clan( - id=row["id"], - name=row["name"], - tag=row["tag"], - created_at=row["created_at"], - owner_id=row["owner"], - member_ids={member["id"] for member in clan_members}, - ) - self.append(clan) - - async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: """Setup & cache the global collections before listening for connections.""" # fetch channels, clans and pools from db await app.state.sessions.channels.prepare(db_conn) - await app.state.sessions.clans.prepare(db_conn) bot = await users_repo.fetch_one(id=1) if bot is None: diff --git a/app/objects/player.py b/app/objects/player.py index 356e8e1d7..d99ba1909 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -34,6 +34,7 @@ from app.objects.match import SlotStatus from app.objects.score import Grade from app.objects.score import Score +from app.repositories import clans as clans_repo from app.repositories import logs as logs_repo from app.repositories import stats as stats_repo from app.repositories import users as users_repo @@ -45,7 +46,6 @@ if TYPE_CHECKING: from app.constants.privileges import ClanPrivileges from app.objects.beatmap import Beatmap - from app.objects.clan import Clan from app.objects.score import 
Score @@ -214,7 +214,7 @@ def __init__( priv: Privileges, pw_bcrypt: bytes | None, token: str, - clan: Clan | None = None, + clan_id: int | None = None, clan_priv: ClanPrivileges | None = None, geoloc: Geolocation | None = None, utc_offset: int = 0, @@ -239,7 +239,7 @@ def __init__( self.priv = priv self.pw_bcrypt = pw_bcrypt self.token = token - self.clan = clan + self.clan_id = clan_id self.clan_priv = clan_priv self.geoloc = geoloc self.utc_offset = utc_offset @@ -315,14 +315,6 @@ def avatar_url(self) -> str: """The url to the player's avatar.""" return f"https://a.{app.settings.DOMAIN}/{self.id}" - @property - def full_name(self) -> str: - """The user's "full" name; including their clan tag.""" - if self.clan: - return f"[{self.clan.tag}] {self.name}" - else: - return self.name - # TODO: chat embed with clan tag hyperlinked? @property @@ -693,24 +685,6 @@ def leave_match(self) -> None: self.match = None - async def join_clan(self, clan: Clan) -> bool: - """Attempt to add `self` to `clan`.""" - if self.id in clan.member_ids: - return False - - if not "invited": # TODO - return False - - await clan.add_member(self) - return True - - async def leave_clan(self) -> None: - """Attempt to remove `self` from `c`.""" - if not self.clan: - return - - await self.clan.remove_member(self) - def join_channel(self, channel: Channel) -> bool: """Attempt to add `self` to `channel`.""" if ( diff --git a/app/repositories/clans.py b/app/repositories/clans.py index 17ba55c47..b894edf75 100644 --- a/app/repositories/clans.py +++ b/app/repositories/clans.py @@ -182,8 +182,8 @@ async def delete(id: int) -> Clan | None: params: dict[str, Any] = { "id": id, } - rec = await app.state.services.database.fetch_one(query, params) - if rec is None: + clan = await app.state.services.database.fetch_one(query, params) + if clan is None: return None query = """\ @@ -193,5 +193,5 @@ async def delete(id: int) -> Clan | None: params = { "id": id, } - clan = await app.state.services.database.execute(query, params) - return cast(Clan, dict(clan._mapping)) if clan is not None else None + await app.state.services.database.execute(query, params) + return cast(Clan, dict(clan._mapping)) diff --git a/app/state/sessions.py b/app/state/sessions.py index 4fe183b82..43356ef8f 100644 --- a/app/state/sessions.py +++ b/app/state/sessions.py @@ -7,7 +7,6 @@ from app.logging import Ansi from app.logging import log from app.objects.collections import Channels -from app.objects.collections import Clans from app.objects.collections import Matches from app.objects.collections import Players @@ -16,7 +15,6 @@ players = Players() channels = Channels() -clans = Clans() matches = Matches() api_keys: dict[str, int] = {} diff --git a/app/utils.py b/app/utils.py index 0e9856a20..bb2df5f75 100644 --- a/app/utils.py +++ b/app/utils.py @@ -7,6 +7,7 @@ import sys from collections.abc import Callable from pathlib import Path +from typing import TYPE_CHECKING from typing import Any from typing import TypedDict from typing import TypeVar @@ -18,6 +19,9 @@ from app.logging import Ansi from app.logging import log +if TYPE_CHECKING: + from app.repositories.users import Player + T = TypeVar("T") @@ -31,6 +35,10 @@ def make_safe_name(name: str) -> str: return name.lower().replace(" ", "_") +def determine_highest_ranking_clan_member(members: list[Player]) -> Player: + return next(iter(sorted(members, key=lambda m: m["clan_priv"], reverse=True))) + + def _download_achievement_images_osu(achievements_path: Path) -> bool: """Download all used achievement images (one 
by one, from osu!).""" achs: list[str] = [] From 336db494e8d37061ca586ad4831568c4b45a7ec6 Mon Sep 17 00:00:00 2001 From: cmyui Date: Sat, 17 Feb 2024 17:38:23 -0500 Subject: [PATCH 23/48] Rename Player to User to match repo terminology --- app/api/domains/cho.py | 2 +- app/api/v1/api.py | 2 +- app/repositories/users.py | 38 +++++++++++++++++++------------------- app/utils.py | 4 ++-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 0d153c9cf..a0f5a693f 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -598,7 +598,7 @@ def parse_adapters_string(adapters_string: str) -> tuple[list[str], bool]: async def authenticate( username: str, untrusted_password: bytes, -) -> users_repo.Player | None: +) -> users_repo.User | None: user_info = await users_repo.fetch_one( name=username, fetch_all_fields=True, diff --git a/app/api/v1/api.py b/app/api/v1/api.py index 6789d32c4..0109a80d5 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -972,7 +972,7 @@ async def api_get_pool( ) pool_creator_clan = await clans_repo.fetch_one(id=pool_creator.clan_id) - pool_creator_clan_members: list[users_repo.Player] = [] + pool_creator_clan_members: list[users_repo.User] = [] if pool_creator_clan is not None: pool_creator_clan_members = await users_repo.fetch_many( clan_id=pool_creator.clan_id, diff --git a/app/repositories/users.py b/app/repositories/users.py index 369572e7e..d38c05c56 100644 --- a/app/repositories/users.py +++ b/app/repositories/users.py @@ -43,7 +43,7 @@ ) -class Player(TypedDict): +class User(TypedDict): id: int name: str safe_name: str @@ -89,8 +89,8 @@ async def create( email: str, pw_bcrypt: bytes, country: str, -) -> Player: - """Create a new player in the database.""" +) -> User: + """Create a new user in the database.""" query = f"""\ INSERT INTO users (name, safe_name, email, pw_bcrypt, country, creation_time, latest_activity) VALUES (:name, :safe_name, :email, :pw_bcrypt, :country, UNIX_TIMESTAMP(), UNIX_TIMESTAMP()) @@ -112,10 +112,10 @@ async def create( params = { "id": rec_id, } - player = await app.state.services.database.fetch_one(query, params) + user = await app.state.services.database.fetch_one(query, params) - assert player is not None - return cast(Player, dict(player._mapping)) + assert user is not None + return cast(User, dict(user._mapping)) async def fetch_one( @@ -123,8 +123,8 @@ async def fetch_one( name: str | None = None, email: str | None = None, fetch_all_fields: bool = False, # TODO: probably remove this if possible -) -> Player | None: - """Fetch a single player from the database.""" +) -> User | None: + """Fetch a single user from the database.""" if id is None and name is None and email is None: raise ValueError("Must provide at least one parameter.") @@ -140,8 +140,8 @@ async def fetch_one( "safe_name": make_safe_name(name) if name is not None else None, "email": email, } - player = await app.state.services.database.fetch_one(query, params) - return cast(Player, dict(player._mapping)) if player is not None else None + user = await app.state.services.database.fetch_one(query, params) + return cast(User, dict(user._mapping)) if user is not None else None async def fetch_count( @@ -152,7 +152,7 @@ async def fetch_count( preferred_mode: int | None = None, play_style: int | None = None, ) -> int: - """Fetch the number of players in the database.""" + """Fetch the number of users in the database.""" query = """\ SELECT COUNT(*) AS count FROM users @@ -185,8 +185,8 @@ async def 
fetch_many( play_style: int | None = None, page: int | None = None, page_size: int | None = None, -) -> list[Player]: - """Fetch multiple players from the database.""" +) -> list[User]: + """Fetch multiple users from the database.""" query = f"""\ SELECT {READ_PARAMS} FROM users @@ -214,8 +214,8 @@ async def fetch_many( params["limit"] = page_size params["offset"] = (page - 1) * page_size - players = await app.state.services.database.fetch_all(query, params) - return cast(list[Player], [dict(p._mapping) for p in players]) + users = await app.state.services.database.fetch_all(query, params) + return cast(list[User], [dict(p._mapping) for p in users]) async def update( @@ -236,8 +236,8 @@ async def update( custom_badge_icon: str | None | _UnsetSentinel = UNSET, userpage_content: str | None | _UnsetSentinel = UNSET, api_key: str | None | _UnsetSentinel = UNSET, -) -> Player | None: - """Update a player in the database.""" +) -> User | None: + """Update a user in the database.""" update_fields: PlayerUpdateFields = {} if not isinstance(name, _UnsetSentinel): update_fields["name"] = name @@ -291,8 +291,8 @@ async def update( params = { "id": id, } - player = await app.state.services.database.fetch_one(query, params) - return cast(Player, dict(player._mapping)) if player is not None else None + user = await app.state.services.database.fetch_one(query, params) + return cast(User, dict(user._mapping)) if user is not None else None # TODO: delete? diff --git a/app/utils.py b/app/utils.py index bb2df5f75..64f675c49 100644 --- a/app/utils.py +++ b/app/utils.py @@ -20,7 +20,7 @@ from app.logging import log if TYPE_CHECKING: - from app.repositories.users import Player + from app.repositories.users import User T = TypeVar("T") @@ -35,7 +35,7 @@ def make_safe_name(name: str) -> str: return name.lower().replace(" ", "_") -def determine_highest_ranking_clan_member(members: list[Player]) -> Player: +def determine_highest_ranking_clan_member(members: list[User]) -> User: return next(iter(sorted(members, key=lambda m: m["clan_priv"], reverse=True))) From 26e460dc1bdd00f62f155eb7240cf8bceb8ee9be Mon Sep 17 00:00:00 2001 From: cmyui Date: Sat, 17 Feb 2024 17:39:38 -0500 Subject: [PATCH 24/48] rename PlayerUpdateFields to UserUpdateFields --- app/repositories/users.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/repositories/users.py b/app/repositories/users.py index d38c05c56..a634f8771 100644 --- a/app/repositories/users.py +++ b/app/repositories/users.py @@ -64,7 +64,7 @@ class User(TypedDict): api_key: str | None -class PlayerUpdateFields(TypedDict, total=False): +class UserUpdateFields(TypedDict, total=False): name: str safe_name: str email: str @@ -238,7 +238,7 @@ async def update( api_key: str | None | _UnsetSentinel = UNSET, ) -> User | None: """Update a user in the database.""" - update_fields: PlayerUpdateFields = {} + update_fields: UserUpdateFields = {} if not isinstance(name, _UnsetSentinel): update_fields["name"] = name update_fields["safe_name"] = make_safe_name(name) From f38dbfe64857b7cdcf330c0af0807344b149b2b4 Mon Sep 17 00:00:00 2001 From: cmyui Date: Sun, 18 Feb 2024 00:20:21 -0500 Subject: [PATCH 25/48] add ip to geoloc fetch error log --- app/state/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/state/services.py b/app/state/services.py index 0365a552d..9f28ab219 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -218,7 +218,7 @@ async def _fetch_geoloc_from_ip(ip: IPAddress) -> Geolocation | None: if err_msg == 
"invalid query": err_msg += f" ({url})" - log(f"Failed to get geoloc data: {err_msg}.", Ansi.LRED) + log(f"Failed to get geoloc data: {err_msg} for ip {ip}.", Ansi.LRED) return None country_acronym = lines[0].lower() From aeafffa309132defa2e5a9c6a7b2c426e2506f06 Mon Sep 17 00:00:00 2001 From: cmyui Date: Sun, 18 Feb 2024 01:24:23 -0500 Subject: [PATCH 26/48] Revert "only allow loopback address to use server ip for geoloc" This reverts commit 112012a4a39c6cbb39f5a19a7523b084d70cc3cf. --- app/state/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/state/services.py b/app/state/services.py index 9f28ab219..0b47166df 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -196,7 +196,7 @@ def __fetch_geoloc_nginx(headers: Mapping[str, str]) -> Geolocation | None: async def _fetch_geoloc_from_ip(ip: IPAddress) -> Geolocation | None: """Fetch geolocation data based on ip (using ip-api).""" - if not ip.is_loopback: + if not ip.is_private: url = f"http://ip-api.com/line/{ip}" else: url = "http://ip-api.com/line/" From bec0b779fba8b784d5a65fb9530a21c0c4bf6cfe Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Sun, 18 Feb 2024 23:32:28 -0500 Subject: [PATCH 27/48] fix migrations (#624) --- app/state/services.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/app/state/services.py b/app/state/services.py index 0b47166df..06728b324 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -401,16 +401,26 @@ async def _get_current_sql_structure_version() -> Version | None: async def run_sql_migrations() -> None: """Update the sql structure, if it has changed.""" - current_ver = await _get_current_sql_structure_version() - if not current_ver: - return # already up to date (server has never run before) - - latest_ver = Version.from_str(app.settings.VERSION) - - if latest_ver is None: + software_version = Version.from_str(app.settings.VERSION) + if software_version is None: raise RuntimeError(f"Invalid bancho.py version '{app.settings.VERSION}'") - if latest_ver == current_ver: + last_run_migration_version = await _get_current_sql_structure_version() + if not last_run_migration_version: + # Migrations have never run before - this is the first time starting the server. + # We'll insert the current version into the database, so future versions know to migrate. + await app.state.services.database.execute( + "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) " + "VALUES (:major, :minor, :micro, NOW())", + { + "major": software_version.major, + "minor": software_version.minor, + "micro": software_version.micro, + }, + ) + return # already up to date (server has never run before) + + if software_version == last_run_migration_version: return # already up to date # version changed; there may be sql changes. @@ -437,7 +447,7 @@ async def run_sql_migrations() -> None: # we only need the updates between the # previous and new version of the server. 
- if current_ver < update_ver <= latest_ver: + if last_run_migration_version < update_ver <= software_version: if line.endswith(";"): if q_lines: q_lines.append(line) @@ -450,7 +460,7 @@ async def run_sql_migrations() -> None: if queries: log( - f"Updating mysql structure (v{current_ver!r} -> v{latest_ver!r}).", + f"Updating mysql structure (v{last_run_migration_version!r} -> v{software_version!r}).", Ansi.LMAGENTA, ) @@ -477,8 +487,8 @@ async def run_sql_migrations() -> None: "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) " "VALUES (:major, :minor, :micro, NOW())", { - "major": latest_ver.major, - "minor": latest_ver.minor, - "micro": latest_ver.micro, + "major": software_version.major, + "minor": software_version.minor, + "micro": software_version.micro, }, ) From 2e389601fa8147171c62f2b3c629a11fab4bc919 Mon Sep 17 00:00:00 2001 From: cmyui Date: Mon, 19 Feb 2024 04:00:39 -0500 Subject: [PATCH 28/48] add a basic explanation of bancho packets to packets.py --- app/packets.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/app/packets.py b/app/packets.py index d35e838ac..9c426032e 100644 --- a/app/packets.py +++ b/app/packets.py @@ -24,8 +24,22 @@ from app.objects.match import Match from app.objects.player import Player -# tuple of some of struct's format specifiers -# for clean access within packet pack/unpack. +# packets are comprised of 3 parts: +# - a unique identifier (the packet id), representing the type of request +# - the length of the request data +# - request data; specific to the packet id + +# the packet id is sent over the wire as an unsigned short (2 bytes, u16) +# the packet data length is sent as an unsigned long (4 bytes, u32) +# the packet data +# - is of variable length +# - may comprise of multiple objects +# - is specific to the request type (packet id) +# - types can vary, but are from a fixed set of possibilities (u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, string, and some higher level types comprising of these primitives) + +# osu! packets are sent in "little endian" ordering. +# little endian: [2, 0, 0, 0] == 2 +# big endian: [0, 0, 0, 2] == 2 @unique @@ -305,7 +319,7 @@ class BanchoPacketReader: Intended Usage: >>> with memoryview(await request.body()) as body_view: - ... for packet in BanchoPacketReader(conn.body): + ... for packet in BanchoPacketReader(body_view): ... 
await packet.handle() """ From 744429db1b8b7f70d478d0e7c1db9999970d9d59 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Mon, 19 Feb 2024 06:07:00 -0500 Subject: [PATCH 29/48] Basic indices for SQL (#625) * Basic indices for SQL * add new indices to migrations.sql * bump to v5.0.1 * add 3 more indices * add 3 more indices --- migrations/base.sql | 56 ++++++++++++++++++++++++++++++++++- migrations/migrations.sql | 62 +++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/migrations/base.sql b/migrations/base.sql index 0e1b08304..b3fdcdc28 100644 --- a/migrations/base.sql +++ b/migrations/base.sql @@ -26,6 +26,8 @@ create table channels constraint channels_name_uindex unique (name) ); +create index channels_auto_join_index + on channels (auto_join); create table clans ( @@ -147,6 +149,18 @@ create table maps constraint maps_md5_uindex unique (md5) ); +create index maps_set_id_index + on maps (set_id); +create index maps_status_index + on maps (status); +create index maps_filename_index + on maps (filename); +create index maps_plays_index + on maps (plays); +create index maps_mode_index + on maps (mode); +create index maps_frozen_index + on maps (frozen); create table mapsets ( @@ -222,6 +236,25 @@ create table scores perfect tinyint(1) not null, online_checksum char(32) not null ); +create index scores_map_md5_index + on scores (map_md5); +create index scores_score_index + on scores (score); +create index scores_pp_index + on scores (pp); +create index scores_mods_index + on scores (mods); +create index scores_status_index + on scores (status); +create index scores_mode_index + on scores (mode); +create index scores_play_time_index + on scores (play_time); +create index scores_userid_index + on scores (userid); +create index scores_online_checksum_index + on scores (online_checksum); + create table startups ( @@ -253,6 +286,14 @@ create table stats a_count int unsigned default 0 not null, primary key (id, mode) ); +create index stats_mode_index + on stats (mode); +create index stats_pp_index + on stats (pp); +create index stats_tscore_index + on stats (tscore); +create index stats_rscore_index + on stats (rscore); create table tourney_pool_maps ( @@ -262,7 +303,8 @@ create table tourney_pool_maps slot tinyint not null, primary key (map_id, pool_id) ); - +create index tourney_pool_maps_mods_slot_index + on tourney_pool_maps (mods, slot); create index tourney_pool_maps_tourney_pools_id_fk on tourney_pool_maps (pool_id); @@ -284,6 +326,10 @@ create table user_achievements achid int not null, primary key (userid, achid) ); +create index user_achievements_achid_index + on user_achievements (achid); +create index user_achievements_userid_index + on user_achievements (userid); create table users ( @@ -316,6 +362,14 @@ create table users constraint users_safe_name_uindex unique (safe_name) ); +create index users_priv_index + on users (priv); +create index users_clan_id_index + on users (clan_id); +create index users_clan_priv_index + on users (clan_priv); +create index users_country_index + on users (country); insert into users (id, name, safe_name, priv, country, silence_end, email, pw_bcrypt, creation_time, latest_activity) values (1, 'BanchoBot', 'banchobot', 1, 'ca', 0, 'bot@akatsuki.pw', diff --git a/migrations/migrations.sql b/migrations/migrations.sql index e524aace7..2db215246 100644 --- a/migrations/migrations.sql +++ b/migrations/migrations.sql @@ -409,3 +409,65 @@ alter table maps drop primary key; alter table maps add 
primary key (id); alter table maps modify column server enum('osu!', 'private') not null default 'osu!' after id; unlock tables; + +# v5.0.1 +create index channels_auto_join_index + on channels (auto_join); + +create index maps_set_id_index + on maps (set_id); +create index maps_status_index + on maps (status); +create index maps_filename_index + on maps (filename); +create index maps_plays_index + on maps (plays); +create index maps_mode_index + on maps (mode); +create index maps_frozen_index + on maps (frozen); + +create index scores_map_md5_index + on scores (map_md5); +create index scores_score_index + on scores (score); +create index scores_pp_index + on scores (pp); +create index scores_mods_index + on scores (mods); +create index scores_status_index + on scores (status); +create index scores_mode_index + on scores (mode); +create index scores_play_time_index + on scores (play_time); +create index scores_userid_index + on scores (userid); +create index scores_online_checksum_index + on scores (online_checksum); + +create index stats_mode_index + on stats (mode); +create index stats_pp_index + on stats (pp); +create index stats_tscore_index + on stats (tscore); +create index stats_rscore_index + on stats (rscore); + +create index tourney_pool_maps_mods_slot_index + on tourney_pool_maps (mods, slot); + +create index user_achievements_achid_index + on user_achievements (achid); +create index user_achievements_userid_index + on user_achievements (userid); + +create index users_priv_index + on users (priv); +create index users_clan_id_index + on users (clan_id); +create index users_clan_priv_index + on users (clan_priv); +create index users_country_index + on users (country); diff --git a/pyproject.toml b/pyproject.toml index 6f9946b13..f130677bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ profile = "black" [tool.poetry] name = "bancho-py" -version = "5.0.0" +version = "5.0.1" description = "An osu! server implementation optimized for maintainability in modern python" authors = ["Akatsuki Team"] license = "MIT" From 0a16cd6ac98f70e1f91f7b6efc1ab19943b336bb Mon Sep 17 00:00:00 2001 From: cmyui Date: Mon, 19 Feb 2024 09:06:23 -0500 Subject: [PATCH 30/48] Remove the !restart command (does not work) --- app/commands.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/app/commands.py b/app/commands.py index e0fca667b..dc122e03b 100644 --- a/app/commands.py +++ b/app/commands.py @@ -995,14 +995,9 @@ async def switchserv(ctx: Context) -> str | None: return "Have a nice journey.." -@command(Privileges.ADMINISTRATOR, aliases=["restart"]) +@command(Privileges.ADMINISTRATOR) async def shutdown(ctx: Context) -> str | None | NoReturn: """Gracefully shutdown the server.""" - if ctx.trigger == "restart": - _signal = signal.SIGUSR1 - else: - _signal = signal.SIGTERM - if ctx.args: # shutdown after a delay delay = timeparse(ctx.args[0]) if not delay: @@ -1020,10 +1015,10 @@ async def shutdown(ctx: Context) -> str | None | NoReturn: app.state.sessions.players.enqueue(app.packets.notification(alert_msg)) - app.state.loop.call_later(delay, os.kill, os.getpid(), _signal) + app.state.loop.call_later(delay, os.kill, os.getpid(), signal.SIGTERM) return f"Enqueued {ctx.trigger}." 
else:  # shutdown immediately
-        os.kill(os.getpid(), _signal)
+        os.kill(os.getpid(), signal.SIGTERM)
         return "Process killed"

From 238d1a1aa6ccefb8ccaefac052183e2f5a457385 Mon Sep 17 00:00:00 2001
From: Josh Smith
Date: Mon, 19 Feb 2024 16:31:35 -0500
Subject: [PATCH 31/48] Refactor scores, stats and channels repos to build
 queries with sqlalchemy core 1.4 (#630)

* bugfixes

Co-authored-by: James Wilson

* begin to refactor

Co-authored-by: James Wilson

* channels, stats

Co-authored-by: James Wilson

* resolve dependencies + add sqlalchemy type stubs

* sqlalchemy-stubs

* the typing bugfix

* _POLEASE

* type fixes

* type fixes

* remove non working test makefile stuff

* fixes

* re-add pymypysql/aiomysql

---------

Co-authored-by: James Wilson
---
 Makefile                         |   6 -
 app/api/domains/osu.py           |   8 +-
 app/commands.py                  |   6 +-
 app/repositories/__init__.py     |  23 +++
 app/repositories/achievements.py | 206 ++++++++-----------
 app/repositories/channels.py     | 253 +++++++++++------------
 app/repositories/scores.py       | 339 +++++++++++++++---------------
 app/repositories/stats.py        | 340 +++++++++++++++----------------
 poetry.lock                      | 147 +++++++------
 pyproject.toml                   |   5 +-
 10 files changed, 656 insertions(+), 677 deletions(-)

diff --git a/Makefile b/Makefile
index e2fe84723..e3d1922c0 100644
--- a/Makefile
+++ b/Makefile
@@ -22,12 +22,6 @@ test:
 	docker-compose -f docker-compose.test.yml up -d bancho-test mysql-test redis-test
 	docker-compose -f docker-compose.test.yml exec -T bancho-test /srv/root/scripts/run-tests.sh
 
-test-local:
-	poetry run pytest -vv tests/
-
-test-dbg:
-	poetry run pytest -vv --pdb -s tests/
-
 lint:
 	poetry run pre-commit run --all-files
 
diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py
index 2795f8d8a..605b6468d 100644
--- a/app/api/domains/osu.py
+++ b/app/api/domains/osu.py
@@ -928,7 +928,7 @@ async def osuSubmitModularSelector(
 
         # update global & country ranking
         stats.rank = await score.player.update_rank(score.mode)
 
-    await stats_repo.update(
+    await stats_repo.partial_update(
         score.player.id,
         score.mode.value,
         plays=stats_updates.get("plays", UNSET),
@@ -1388,7 +1388,11 @@ async def getScores(
         return Response("\n".join(response_lines).encode())
 
     if personal_best_score_row is not None:
-        user_clan = await clans_repo.fetch_one(id=player.clan_id)
+        user_clan = (
+            await clans_repo.fetch_one(id=player.clan_id)
+            if player.clan_id is not None
+            else None
+        )
         display_name = (
             f"[{user_clan['tag']}] {player.name}"
             if user_clan is not None
diff --git a/app/commands.py b/app/commands.py
index dc122e03b..dfcb22acf 100644
--- a/app/commands.py
+++ b/app/commands.py
@@ -867,7 +867,11 @@ async def user(ctx: Context) -> str | None:
         else "False"
     )
 
-    user_clan = await clans_repo.fetch_one(id=player.clan_id)
+    user_clan = (
+        await clans_repo.fetch_one(id=player.clan_id)
+        if player.clan_id is not None
+        else None
+    )
     display_name = (
         f"[{user_clan['tag']}] {player.name}" if user_clan is not None else player.name
     )
diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py
index e69de29bb..75009bc1b 100644
--- a/app/repositories/__init__.py
+++ b/app/repositories/__init__.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
+from sqlalchemy.orm import DeclarativeMeta
+from sqlalchemy.orm import registry
+
+mapper_registry = registry()
+
+
+class Base(metaclass=DeclarativeMeta):
+    __abstract__ = True
+
+    registry = mapper_registry
+    metadata = mapper_registry.metadata
+
+    __init__ = mapper_registry.constructor
+
+ +class MySQLDialect(MySQLDialect_mysqldb): + default_paramstyle = "named" + + +DIALECT = MySQLDialect() diff --git a/app/repositories/achievements.py b/app/repositories/achievements.py index 56d54e4d9..e2ee00ae5 100644 --- a/app/repositories/achievements.py +++ b/app/repositories/achievements.py @@ -1,6 +1,5 @@ from __future__ import annotations -import textwrap from collections.abc import Callable from typing import TYPE_CHECKING from typing import Any @@ -10,24 +9,45 @@ import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import DIALECT +from app.repositories import Base if TYPE_CHECKING: from app.objects.score import Score -# +-------+--------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +-------+--------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | file | varchar(128) | NO | UNI | NULL | | -# | name | varchar(128) | NO | UNI | NULL | | -# | desc | varchar(256) | NO | UNI | NULL | | -# | cond | varchar(64) | NO | | NULL | | -# +-------+--------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, file, name, `desc`, cond - """, +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update + + +class AchievementsTable(Base): + __tablename__ = "achievements" + + id = Column("id", Integer, primary_key=True) + file = Column("file", String(128), nullable=False) + name = Column("name", String(128, collation="utf8"), nullable=False) + desc = Column("desc", String(256, collation="utf8"), nullable=False) + cond = Column("cond", String(64), nullable=False) + + __table_args__ = ( + Index("achievements_desc_uindex", desc, unique=True), + Index("achievements_file_uindex", file, unique=True), + Index("achievements_name_uindex", name, unique=True), + ) + + +READ_PARAMS = ( + AchievementsTable.id, + AchievementsTable.file, + AchievementsTable.name, + AchievementsTable.desc, + AchievementsTable.cond, ) @@ -39,13 +59,6 @@ class Achievement(TypedDict): cond: Callable[[Score, int], bool] -class AchievementUpdateFields(TypedDict, total=False): - file: str - name: str - desc: str - cond: str - - async def create( file: str, name: str, @@ -53,27 +66,20 @@ async def create( cond: str, ) -> Achievement: """Create a new achievement.""" - query = """\ - INSERT INTO achievements (file, name, desc, cond) - VALUES (:file, :name, :desc, :cond) - """ - params: dict[str, Any] = { - "file": file, - "name": name, - "desc": desc, - "cond": cond, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - WHERE id = :id - """ - params = { - "id": rec_id, - } - rec = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(AchievementsTable).values( + file=file, + name=name, + desc=desc, + cond=cond, + ) + compiled = insert_stmt.compile(dialect=DIALECT) + + rec_id = await app.state.services.database.execute(str(compiled), compiled.params) + + select_stmt = select(READ_PARAMS).where(AchievementsTable.id == rec_id) + compiled = select_stmt.compile(dialect=DIALECT) + + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) assert rec is not None achievement = 
dict(rec._mapping) @@ -89,17 +95,15 @@ async def fetch_one( if id is None and name is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - WHERE id = COALESCE(:id, id) - OR name = COALESCE(:name, name) - """ - params: dict[str, Any] = { - "id": id, - "name": name, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(READ_PARAMS) + + if id is not None: + select_stmt = select_stmt.where(AchievementsTable.id == id) + if name is not None: + select_stmt = select_stmt.where(AchievementsTable.name == name) + + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) if rec is None: return None @@ -111,13 +115,10 @@ async def fetch_one( async def fetch_count() -> int: """Fetch the number of achievements.""" - query = """\ - SELECT COUNT(*) AS count - FROM achievements - """ - params: dict[str, Any] = {} + select_stmt = select(func.count().label("count")).select_from(AchievementsTable) + compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(query, params) + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -127,21 +128,16 @@ async def fetch_many( page_size: int | None = None, ) -> list[Achievement]: """Fetch a list of achievements.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - """ - params: dict[str, Any] = {} - + select_stmt = select(READ_PARAMS) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["page_size"] = page_size - params["offset"] = (page - 1) * page_size + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - records = await app.state.services.database.fetch_all(query, params) + compiled = select_stmt.compile(dialect=DIALECT) + + records = await app.state.services.database.fetch_all( + str(compiled), + compiled.params, + ) achievements: list[dict[str, Any]] = [] @@ -153,7 +149,7 @@ async def fetch_many( return cast(list[Achievement], achievements) -async def update( +async def partial_update( id: int, file: str | _UnsetSentinel = UNSET, name: str | _UnsetSentinel = UNSET, @@ -161,35 +157,22 @@ async def update( cond: str | _UnsetSentinel = UNSET, ) -> Achievement | None: """Update an existing achievement.""" - update_fields: AchievementUpdateFields = {} + update_stmt = update(AchievementsTable).where(AchievementsTable.id == id) if not isinstance(file, _UnsetSentinel): - update_fields["file"] = file + update_stmt = update_stmt.values(file=file) if not isinstance(name, _UnsetSentinel): - update_fields["name"] = name + update_stmt = update_stmt.values(name=name) if not isinstance(desc, _UnsetSentinel): - update_fields["desc"] = desc + update_stmt = update_stmt.values(desc=desc) if not isinstance(cond, _UnsetSentinel): - update_fields["cond"] = cond - - query = f"""\ - UPDATE achievements - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } | update_fields - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - WHERE id = :id - """ - params = { - "id": id, - } - rec = await app.state.services.database.fetch_one(query, params) + update_stmt = update_stmt.values(cond=cond) + + compiled = update_stmt.compile(dialect=DIALECT) 
+ await app.state.services.database.execute(str(compiled), compiled.params) + + select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) assert rec is not None achievement = dict(rec._mapping) @@ -197,30 +180,19 @@ async def update( return cast(Achievement, achievement) -async def delete( +async def delete_one( id: int, ) -> Achievement | None: """Delete an existing achievement.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) if rec is None: return None - query = """\ - DELETE FROM achievements - WHERE id = :id - """ - params = { - "id": id, - } - await app.state.services.database.execute(query, params) + delete_stmt = delete(AchievementsTable).where(AchievementsTable.id == id) + compiled = delete_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute(str(compiled), compiled.params) achievement = dict(rec._mapping) achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') diff --git a/app/repositories/channels.py b/app/repositories/channels.py index da89f0595..35f7418f4 100644 --- a/app/repositories/channels.py +++ b/app/repositories/channels.py @@ -5,25 +5,47 @@ from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import DIALECT +from app.repositories import Base + + +class ChannelsTable(Base): + __tablename__ = "channels" + + id = Column("id", Integer, primary_key=True) + name = Column("name", String(32), nullable=False) + topic = Column("topic", String(256), nullable=False) + read_priv = Column("read_priv", Integer, nullable=False, server_default="1") + write_priv = Column("write_priv", Integer, nullable=False, server_default="2") + auto_join = Column("auto_join", TINYINT(1), nullable=False, server_default="0") + + __table_args__ = ( + Index("channels_name_uindex", name, unique=True), + Index("channels_auto_join_index", auto_join), + ) + -# +------------+--------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +------------+--------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | name | varchar(32) | NO | UNI | NULL | | -# | topic | varchar(256) | NO | | NULL | | -# | read_priv | int | NO | | 1 | | -# | write_priv | int | NO | | 2 | | -# | auto_join | tinyint(1) | NO | | 0 | | -# +------------+--------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, name, topic, read_priv, write_priv, auto_join - """, +READ_PARAMS = ( + ChannelsTable.id, + ChannelsTable.name, + ChannelsTable.topic, + ChannelsTable.read_priv, + ChannelsTable.write_priv, + ChannelsTable.auto_join, ) @@ -36,13 +58,6 @@ 
class Channel(TypedDict): auto_join: bool -class ChannelUpdateFields(TypedDict, total=False): - topic: str - read_priv: int - write_priv: int - auto_join: bool - - async def create( name: str, topic: str, @@ -51,30 +66,22 @@ async def create( auto_join: bool, ) -> Channel: """Create a new channel.""" - query = """\ - INSERT INTO channels (name, topic, read_priv, write_priv, auto_join) - VALUES (:name, :topic, :read_priv, :write_priv, :auto_join) - - """ - params: dict[str, Any] = { - "name": name, - "topic": topic, - "read_priv": read_priv, - "write_priv": write_priv, - "auto_join": auto_join, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE id = :id - """ - params = { - "id": rec_id, - } - - channel = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(ChannelsTable).values( + name=name, + topic=topic, + read_priv=read_priv, + write_priv=write_priv, + auto_join=auto_join, + ) + compiled = insert_stmt.compile(dialect=DIALECT) + rec_id = await app.state.services.database.execute(str(compiled), compiled.params) + + select_stmt = select(READ_PARAMS).where(ChannelsTable.id == rec_id) + compiled = select_stmt.compile(dialect=DIALECT) + channel = await app.state.services.database.fetch_one( + str(compiled), + compiled.params, + ) assert channel is not None return cast(Channel, dict(channel._mapping)) @@ -87,17 +94,19 @@ async def fetch_one( """Fetch a single channel.""" if id is None and name is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE id = COALESCE(:id, id) - AND name = COALESCE(:name, name) - """ - params: dict[str, Any] = { - "id": id, - "name": name, - } - channel = await app.state.services.database.fetch_one(query, params) + + select_stmt = select(READ_PARAMS) + + if id is not None: + select_stmt = select_stmt.where(ChannelsTable.id == id) + if name is not None: + select_stmt = select_stmt.where(ChannelsTable.name == name) + + compiled = select_stmt.compile(dialect=DIALECT) + channel = await app.state.services.database.fetch_one( + str(compiled), + compiled.params, + ) return cast(Channel, dict(channel._mapping)) if channel is not None else None @@ -110,20 +119,17 @@ async def fetch_count( if read_priv is None and write_priv is None and auto_join is None: raise ValueError("Must provide at least one parameter.") - query = """\ - SELECT COUNT(*) AS count - FROM channels - WHERE read_priv = COALESCE(:read_priv, read_priv) - AND write_priv = COALESCE(:write_priv, write_priv) - AND auto_join = COALESCE(:auto_join, auto_join) - """ - params: dict[str, Any] = { - "read_priv": read_priv, - "write_priv": write_priv, - "auto_join": auto_join, - } - - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(func.count().label("count")).select_from(ChannelsTable) + + if read_priv is not None: + select_stmt = select_stmt.where(ChannelsTable.read_priv == read_priv) + if write_priv is not None: + select_stmt = select_stmt.where(ChannelsTable.write_priv == write_priv) + if auto_join is not None: + select_stmt = select_stmt.where(ChannelsTable.auto_join == auto_join) + + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -136,32 +142,27 @@ async def fetch_many( page_size: int | None = None, ) -> list[Channel]: """Fetch multiple 
channels from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE read_priv = COALESCE(:read_priv, read_priv) - AND write_priv = COALESCE(:write_priv, write_priv) - AND auto_join = COALESCE(:auto_join, auto_join) - """ - params: dict[str, Any] = { - "read_priv": read_priv, - "write_priv": write_priv, - "auto_join": auto_join, - } + select_stmt = select(READ_PARAMS) + + if read_priv is not None: + select_stmt = select_stmt.where(ChannelsTable.read_priv == read_priv) + if write_priv is not None: + select_stmt = select_stmt.where(ChannelsTable.write_priv == write_priv) + if auto_join is not None: + select_stmt = select_stmt.where(ChannelsTable.auto_join == auto_join) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size - - channels = await app.state.services.database.fetch_all(query, params) + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + + compiled = select_stmt.compile(dialect=DIALECT) + channels = await app.state.services.database.fetch_all( + str(compiled), + compiled.params, + ) return cast(list[Channel], [dict(c._mapping) for c in channels]) -async def update( +async def partial_update( name: str, topic: str | _UnsetSentinel = UNSET, read_priv: int | _UnsetSentinel = UNSET, @@ -169,60 +170,40 @@ async def update( auto_join: bool | _UnsetSentinel = UNSET, ) -> Channel | None: """Update a channel in the database.""" - update_fields: ChannelUpdateFields = {} + update_stmt = update(ChannelsTable).where(ChannelsTable.name == name) + if not isinstance(topic, _UnsetSentinel): - update_fields["topic"] = topic + update_stmt = update_stmt.values(topic=topic) if not isinstance(read_priv, _UnsetSentinel): - update_fields["read_priv"] = read_priv + update_stmt = update_stmt.values(read_priv=read_priv) if not isinstance(write_priv, _UnsetSentinel): - update_fields["write_priv"] = write_priv + update_stmt = update_stmt.values(write_priv=write_priv) if not isinstance(auto_join, _UnsetSentinel): - update_fields["auto_join"] = auto_join - - query = f"""\ - UPDATE channels - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE name = :name - """ - params: dict[str, Any] = { - "name": name, - } | update_fields - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE name = :name - """ - params = { - "name": name, - } - channel = await app.state.services.database.fetch_one(query, params) + update_stmt = update_stmt.values(auto_join=auto_join) + + compiled = update_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute(str(compiled), compiled.params) + + select_stmt = select(READ_PARAMS).where(ChannelsTable.name == name) + compiled = select_stmt.compile(dialect=DIALECT) + channel = await app.state.services.database.fetch_one( + str(compiled), + compiled.params, + ) return cast(Channel, dict(channel._mapping)) if channel is not None else None -async def delete( +async def delete_one( name: str, ) -> Channel | None: """Delete a channel from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE name = :name - """ - params: dict[str, Any] = { - "name": name, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(READ_PARAMS).where(ChannelsTable.name == name) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await 
app.state.services.database.fetch_one(str(compiled), compiled.params) if rec is None: return None - query = """\ - DELETE FROM channels - WHERE name = :name - """ - params = { - "name": name, - } - channel = await app.state.services.database.execute(query, params) - return cast(Channel, dict(channel._mapping)) if channel is not None else None + delete_stmt = delete(ChannelsTable).where(ChannelsTable.name == name) + compiled = delete_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute(str(compiled), compiled.params) + return cast(Channel, dict(rec._mapping)) if rec is not None else None diff --git a/app/repositories/scores.py b/app/repositories/scores.py index b774307c2..d1b9dc4fc 100644 --- a/app/repositories/scores.py +++ b/app/repositories/scores.py @@ -1,47 +1,90 @@ from __future__ import annotations -import textwrap from datetime import datetime -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.dialects.mysql import FLOAT +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import DIALECT +from app.repositories import Base + + +class ScoresTable(Base): + __tablename__ = "scores" + + id = Column("id", Integer, primary_key=True) + map_md5 = Column("map_md5", String(32), nullable=False) + score = Column("score", Integer, nullable=False) + pp = Column("pp", FLOAT(precision=6, scale=3), nullable=False) + acc = Column("acc", FLOAT(precision=6, scale=3), nullable=False) + max_combo = Column("max_combo", Integer, nullable=False) + mods = Column("mods", Integer, nullable=False) + n300 = Column("n300", Integer, nullable=False) + n100 = Column("n100", Integer, nullable=False) + n50 = Column("n50", Integer, nullable=False) + nmiss = Column("nmiss", Integer, nullable=False) + ngeki = Column("ngeki", Integer, nullable=False) + nkatu = Column("nkatu", Integer, nullable=False) + grade = Column("grade", String(2), nullable=False, server_default="N") + status = Column("status", Integer, nullable=False) + mode = Column("mode", Integer, nullable=False) + play_time = Column("play_time", DateTime, nullable=False) + time_elapsed = Column("time_elapsed", Integer, nullable=False) + client_flags = Column("client_flags", Integer, nullable=False) + userid = Column("userid", Integer, nullable=False) + perfect = Column("perfect", TINYINT(1), nullable=False) + online_checksum = Column("online_checksum", String(32), nullable=False) -# +-----------------+-----------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +-----------------+-----------------+------+-----+---------+----------------+ -# | id | bigint unsigned | NO | PRI | NULL | auto_increment | -# | map_md5 | char(32) | NO | | NULL | | -# | score | int | NO | | NULL | | -# | pp | float(7,3) | NO | | NULL | | -# | acc | float(6,3) | NO | | NULL | | -# | max_combo | int | NO | | NULL | | -# | mods | int | NO | | NULL | | -# | n300 | int | NO | | NULL | | -# | n100 | int | NO | | NULL | | -# | n50 | int | NO | | NULL | | -# | nmiss | int | NO | | NULL | | -# | ngeki | int | NO | | NULL | | -# | nkatu | int | NO | | NULL | | -# | grade | varchar(2) | NO | | N | | 
-# | status | tinyint | NO | | NULL | | -# | mode | tinyint | NO | | NULL | | -# | play_time | datetime | NO | | NULL | | -# | time_elapsed | int | NO | | NULL | | -# | client_flags | int | NO | | NULL | | -# | userid | int | NO | | NULL | | -# | perfect | tinyint(1) | NO | | NULL | | -# | online_checksum | char(32) | NO | | NULL | | -# +-----------------+-----------------+------+-----+---------+----------------+ + __table_args__ = ( + Index("scores_map_md5_index", map_md5), + Index("scores_score_index", score), + Index("scores_pp_index", pp), + Index("scores_mods_index", mods), + Index("scores_status_index", status), + Index("scores_mode_index", mode), + Index("scores_play_time_index", play_time), + Index("scores_userid_index", userid), + Index("scores_online_checksum_index", online_checksum), + ) -READ_PARAMS = textwrap.dedent( - """\ - id, map_md5, score, pp, acc, max_combo, mods, n300, n100, n50, nmiss, ngeki, nkatu, - grade, status, mode, play_time, time_elapsed, client_flags, userid, perfect, online_checksum - """, + +READ_PARAMS = ( + ScoresTable.id, + ScoresTable.map_md5, + ScoresTable.score, + ScoresTable.pp, + ScoresTable.acc, + ScoresTable.max_combo, + ScoresTable.mods, + ScoresTable.n300, + ScoresTable.n100, + ScoresTable.n50, + ScoresTable.nmiss, + ScoresTable.ngeki, + ScoresTable.nkatu, + ScoresTable.grade, + ScoresTable.status, + ScoresTable.mode, + ScoresTable.play_time, + ScoresTable.time_elapsed, + ScoresTable.client_flags, + ScoresTable.userid, + ScoresTable.perfect, + ScoresTable.online_checksum, ) @@ -70,30 +113,6 @@ class Score(TypedDict): online_checksum: str -class ScoreUpdateFields(TypedDict, total=False): - map_md5: str - score: int - pp: float - acc: float - max_combo: int - mods: int - n300: int - n100: int - n50: int - nmiss: int - ngeki: int - nkatu: int - grade: str - status: int - mode: int - play_time: datetime - time_elapsed: int - client_flags: int - userid: int - perfect: int - online_checksum: str - - async def create( map_md5: str, score: int, @@ -117,65 +136,52 @@ async def create( perfect: int, online_checksum: str, ) -> Score: - query = """\ - INSERT INTO scores (map_md5, score, pp, acc, max_combo, mods, n300, - n100, n50, nmiss, ngeki, nkatu, grade, status, - mode, play_time, time_elapsed, client_flags, - userid, perfect, online_checksum) - VALUES (:map_md5, :score, :pp, :acc, :max_combo, :mods, :n300, - :n100, :n50, :nmiss, :ngeki, :nkatu, :grade, :status, - :mode, :play_time, :time_elapsed, :client_flags, - :userid, :perfect, :online_checksum) - """ - params: dict[str, Any] = { - "map_md5": map_md5, - "score": score, - "pp": pp, - "acc": acc, - "max_combo": max_combo, - "mods": mods, - "n300": n300, - "n100": n100, - "n50": n50, - "nmiss": nmiss, - "ngeki": ngeki, - "nkatu": nkatu, - "grade": grade, - "status": status, - "mode": mode, - "play_time": play_time, - "time_elapsed": time_elapsed, - "client_flags": client_flags, - "userid": user_id, - "perfect": perfect, - "online_checksum": online_checksum, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM scores - WHERE id = :id - """ - params = { - "id": rec_id, - } - rec = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(ScoresTable).values( + map_md5=map_md5, + score=score, + pp=pp, + acc=acc, + max_combo=max_combo, + mods=mods, + n300=n300, + n100=n100, + n50=n50, + nmiss=nmiss, + ngeki=ngeki, + nkatu=nkatu, + grade=grade, + status=status, + mode=mode, + play_time=play_time, + 
time_elapsed=time_elapsed, + client_flags=client_flags, + userid=user_id, + perfect=perfect, + online_checksum=online_checksum, + ) + compiled = insert_stmt.compile(dialect=DIALECT) + rec_id = await app.state.services.database.execute( + query=str(compiled), + values=compiled.params, + ) + select_stmt = select(READ_PARAMS).where(ScoresTable.id == rec_id) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) assert rec is not None return cast(Score, dict(rec._mapping)) async def fetch_one(id: int) -> Score | None: - query = f"""\ - SELECT {READ_PARAMS} - FROM scores - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(READ_PARAMS).where(ScoresTable.id == id) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) return cast(Score, dict(rec._mapping)) if rec is not None else None @@ -187,23 +193,23 @@ async def fetch_count( mode: int | None = None, user_id: int | None = None, ) -> int: - query = """\ - SELECT COUNT(*) AS count - FROM scores - WHERE map_md5 = COALESCE(:map_md5, map_md5) - AND mods = COALESCE(:mods, mods) - AND status = COALESCE(:status, status) - AND mode = COALESCE(:mode, mode) - AND userid = COALESCE(:userid, userid) - """ - params: dict[str, Any] = { - "map_md5": map_md5, - "mods": mods, - "status": status, - "mode": mode, - "userid": user_id, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(func.count().label("count")).select_from(ScoresTable) + if map_md5 is not None: + select_stmt = select_stmt.where(ScoresTable.map_md5 == map_md5) + if mods is not None: + select_stmt = select_stmt.where(ScoresTable.mods == mods) + if status is not None: + select_stmt = select_stmt.where(ScoresTable.status == status) + if mode is not None: + select_stmt = select_stmt.where(ScoresTable.mode == mode) + if user_id is not None: + select_stmt = select_stmt.where(ScoresTable.userid == user_id) + + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) assert rec is not None return cast(int, rec._mapping["count"]) @@ -217,65 +223,52 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Score]: - query = f"""\ - SELECT {READ_PARAMS} - FROM scores - WHERE map_md5 = COALESCE(:map_md5, map_md5) - AND mods = COALESCE(:mods, mods) - AND status = COALESCE(:status, status) - AND mode = COALESCE(:mode, mode) - AND userid = COALESCE(:userid, userid) - """ - params: dict[str, Any] = { - "map_md5": map_md5, - "mods": mods, - "status": status, - "mode": mode, - "userid": user_id, - } + select_stmt = select(READ_PARAMS) + if map_md5 is not None: + select_stmt = select_stmt.where(ScoresTable.map_md5 == map_md5) + if mods is not None: + select_stmt = select_stmt.where(ScoresTable.mods == mods) + if status is not None: + select_stmt = select_stmt.where(ScoresTable.status == status) + if mode is not None: + select_stmt = select_stmt.where(ScoresTable.mode == mode) + if user_id is not None: + select_stmt = select_stmt.where(ScoresTable.userid == user_id) + if page is not None and page_size is not None: - query += """\ - LIMIT :page_size - OFFSET :offset - """ - params["page_size"] = page_size - params["offset"] = (page - 1) * 
page_size + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - recs = await app.state.services.database.fetch_all(query, params) + compiled = select_stmt.compile(dialect=DIALECT) + recs = await app.state.services.database.fetch_all( + query=str(compiled), + values=compiled.params, + ) return cast(list[Score], [dict(r._mapping) for r in recs]) -async def update( +async def partial_update( id: int, pp: float | _UnsetSentinel = UNSET, status: int | _UnsetSentinel = UNSET, ) -> Score | None: """Update an existing score.""" - update_fields: ScoreUpdateFields = {} + update_stmt = update(ScoresTable).where(ScoresTable.id == id) if not isinstance(pp, _UnsetSentinel): - update_fields["pp"] = pp + update_stmt = update_stmt.values(pp=pp) if not isinstance(status, _UnsetSentinel): - update_fields["status"] = status - - query = f"""\ - UPDATE scores - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } | update_fields - await app.state.services.database.execute(query, params) + update_stmt = update_stmt.values(status=status) + compiled = update_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute( + query=str(compiled), + values=compiled.params, + ) - query = f"""\ - SELECT {READ_PARAMS} - FROM scores - WHERE id = :id - """ - params = { - "id": id, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(READ_PARAMS).where(ScoresTable.id == id) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) return cast(Score, dict(rec._mapping)) if rec is not None else None diff --git a/app/repositories/stats.py b/app/repositories/stats.py index 30d6635e4..12177f875 100644 --- a/app/repositories/stats.py +++ b/app/repositories/stats.py @@ -1,40 +1,77 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update + +# from sqlalchemy import update +from sqlalchemy.dialects.mysql import FLOAT +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import DIALECT +from app.repositories import Base + + +class StatsTable(Base): + __tablename__ = "stats" -# +--------------+-----------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +--------------+-----------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | mode | tinyint(1) | NO | PRI | NULL | | -# | tscore | bigint unsigned | NO | | 0 | | -# | rscore | bigint unsigned | NO | | 0 | | -# | pp | int unsigned | NO | | 0 | | -# | plays | int unsigned | NO | | 0 | | -# | playtime | int unsigned | NO | | 0 | | -# | acc | float(6,3) | NO | | 0.000 | | -# | max_combo | int unsigned | NO | | 0 | | -# | total_hits | int unsigned | NO | | 0 | | -# | replay_views | int unsigned | NO | | 0 | | -# | xh_count | int unsigned | NO | | 0 | | -# | x_count | int unsigned | NO | | 0 | | -# | sh_count | int unsigned | NO | | 0 | | -# | s_count | int unsigned | NO | | 0 | | -# | a_count | int unsigned | NO | | 0 | | -# 
+--------------+-----------------+------+-----+---------+----------------+ + id = Column("id", Integer, primary_key=True) + mode = Column("mode", TINYINT(1), primary_key=True) + tscore = Column("tscore", Integer, nullable=False, server_default="0") + rscore = Column("rscore", Integer, nullable=False, server_default="0") + pp = Column("pp", Integer, nullable=False, server_default="0") + plays = Column("plays", Integer, nullable=False, server_default="0") + playtime = Column("playtime", Integer, nullable=False, server_default="0") + acc = Column( + "acc", + FLOAT(precision=6, scale=3), + nullable=False, + server_default="0.000", + ) + max_combo = Column("max_combo", Integer, nullable=False, server_default="0") + total_hits = Column("total_hits", Integer, nullable=False, server_default="0") + replay_views = Column("replay_views", Integer, nullable=False, server_default="0") + xh_count = Column("xh_count", Integer, nullable=False, server_default="0") + x_count = Column("x_count", Integer, nullable=False, server_default="0") + sh_count = Column("sh_count", Integer, nullable=False, server_default="0") + s_count = Column("s_count", Integer, nullable=False, server_default="0") + a_count = Column("a_count", Integer, nullable=False, server_default="0") -READ_PARAMS = textwrap.dedent( - """\ - id, mode, tscore, rscore, pp, plays, playtime, acc, max_combo, total_hits, - replay_views, xh_count, x_count, sh_count, s_count, a_count - """, + __table_args__ = ( + Index("stats_mode_index", mode), + Index("stats_pp_index", pp), + Index("stats_tscore_index", tscore), + Index("stats_rscore_index", rscore), + ) + + +READ_PARAMS = ( + StatsTable.id, + StatsTable.mode, + StatsTable.tscore, + StatsTable.rscore, + StatsTable.pp, + StatsTable.plays, + StatsTable.playtime, + StatsTable.acc, + StatsTable.max_combo, + StatsTable.total_hits, + StatsTable.replay_views, + StatsTable.xh_count, + StatsTable.x_count, + StatsTable.sh_count, + StatsTable.s_count, + StatsTable.a_count, ) @@ -57,100 +94,66 @@ class Stat(TypedDict): a_count: int -class StatUpdateFields(TypedDict, total=False): - tscore: int - rscore: int - pp: int - plays: int - playtime: int - acc: float - max_combo: int - total_hits: int - replay_views: int - xh_count: int - x_count: int - sh_count: int - s_count: int - a_count: int - - -async def create( - player_id: int, - mode: int, - # TODO: should we allow init with values? 
-) -> Stat: +async def create(player_id: int, mode: int) -> Stat: """Create a new player stats entry in the database.""" - query = f"""\ - INSERT INTO stats (id, mode) - VALUES (:id, :mode) - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = :id - """ - params = { - "id": rec_id, - } - stat = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(StatsTable).values(id=player_id, mode=mode) + compiled = insert_stmt.compile(dialect=DIALECT) + rec_id = await app.state.services.database.execute( + query=str(compiled), + values=compiled.params, + ) + select_stmt = select(READ_PARAMS).where(StatsTable.id == rec_id) + compiled = select_stmt.compile(dialect=DIALECT) + stat = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) assert stat is not None return cast(Stat, dict(stat._mapping)) async def create_all_modes(player_id: int) -> list[Stat]: """Create new player stats entries for each game mode in the database.""" - query = f"""\ - INSERT INTO stats (id, mode) - VALUES (:id, :mode) - """ - params_list = [ - {"id": player_id, "mode": mode} - for mode in ( - 0, # vn!std - 1, # vn!taiko - 2, # vn!catch - 3, # vn!mania - 4, # rx!std - 5, # rx!taiko - 6, # rx!catch - 8, # ap!std - ) - ] - await app.state.services.database.execute_many(query, params_list) + insert_stmt = insert(StatsTable).values( + [ + {"id": player_id, "mode": mode} + for mode in ( + 0, # vn!std + 1, # vn!taiko + 2, # vn!catch + 3, # vn!mania + 4, # rx!std + 5, # rx!taiko + 6, # rx!catch + 8, # ap!std + ) + ], + ) + compiled = insert_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute(str(compiled), compiled.params) - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = :id - """ - params: dict[str, Any] = { - "id": player_id, - } - stats = await app.state.services.database.fetch_all(query, params) + select_stmt = select(READ_PARAMS).where(StatsTable.id == player_id) + compiled = select_stmt.compile(dialect=DIALECT) + stats = await app.state.services.database.fetch_all( + query=str(compiled), + values=compiled.params, + ) return cast(list[Stat], [dict(s._mapping) for s in stats]) async def fetch_one(player_id: int, mode: int) -> Stat | None: """Fetch a player stats entry from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = :id - AND mode = :mode - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } - stat = await app.state.services.database.fetch_one(query, params) - + select_stmt = ( + select(READ_PARAMS) + .where(StatsTable.id == player_id) + .where(StatsTable.mode == mode) + ) + compiled = select_stmt.compile(dialect=DIALECT) + stat = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) return cast(Stat, dict(stat._mapping)) if stat is not None else None @@ -158,17 +161,16 @@ async def fetch_count( player_id: int | None = None, mode: int | None = None, ) -> int: - query = """\ - SELECT COUNT(*) AS count - FROM stats - WHERE id = COALESCE(:id, id) - AND mode = COALESCE(:mode, mode) - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(func.count().label("count")).select_from(StatsTable) + if player_id is not None: + select_stmt = select_stmt.where(StatsTable.id == 
player_id) + if mode is not None: + select_stmt = select_stmt.where(StatsTable.mode == mode) + compiled = select_stmt.compile(dialect=DIALECT) + rec = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) assert rec is not None return cast(int, rec._mapping["count"]) @@ -179,30 +181,22 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Stat]: - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = COALESCE(:id, id) - AND mode = COALESCE(:mode, mode) - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } - + select_stmt = select(READ_PARAMS) + if player_id is not None: + select_stmt = select_stmt.where(StatsTable.id == player_id) + if mode is not None: + select_stmt = select_stmt.where(StatsTable.mode == mode) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size - - stats = await app.state.services.database.fetch_all(query, params) + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + compiled = select_stmt.compile(dialect=DIALECT) + stats = await app.state.services.database.fetch_all( + query=str(compiled), + values=compiled.params, + ) return cast(list[Stat], [dict(s._mapping) for s in stats]) -async def update( +async def partial_update( player_id: int, mode: int, tscore: int | _UnsetSentinel = UNSET, @@ -221,60 +215,54 @@ async def update( a_count: int | _UnsetSentinel = UNSET, ) -> Stat | None: """Update a player stats entry in the database.""" - update_fields: StatUpdateFields = {} + update_stmt = ( + update(StatsTable) + .where(StatsTable.id == player_id) + .where(StatsTable.mode == mode) + ) if not isinstance(tscore, _UnsetSentinel): - update_fields["tscore"] = tscore + update_stmt = update_stmt.values(tscore=tscore) if not isinstance(rscore, _UnsetSentinel): - update_fields["rscore"] = rscore + update_stmt = update_stmt.values(rscore=rscore) if not isinstance(pp, _UnsetSentinel): - update_fields["pp"] = pp + update_stmt = update_stmt.values(pp=pp) if not isinstance(plays, _UnsetSentinel): - update_fields["plays"] = plays + update_stmt = update_stmt.values(plays=plays) if not isinstance(playtime, _UnsetSentinel): - update_fields["playtime"] = playtime + update_stmt = update_stmt.values(playtime=playtime) if not isinstance(acc, _UnsetSentinel): - update_fields["acc"] = acc + update_stmt = update_stmt.values(acc=acc) if not isinstance(max_combo, _UnsetSentinel): - update_fields["max_combo"] = max_combo + update_stmt = update_stmt.values(max_combo=max_combo) if not isinstance(total_hits, _UnsetSentinel): - update_fields["total_hits"] = total_hits + update_stmt = update_stmt.values(total_hits=total_hits) if not isinstance(replay_views, _UnsetSentinel): - update_fields["replay_views"] = replay_views + update_stmt = update_stmt.values(replay_views=replay_views) if not isinstance(xh_count, _UnsetSentinel): - update_fields["xh_count"] = xh_count + update_stmt = update_stmt.values(xh_count=xh_count) if not isinstance(x_count, _UnsetSentinel): - update_fields["x_count"] = x_count + update_stmt = update_stmt.values(x_count=x_count) if not isinstance(sh_count, _UnsetSentinel): - update_fields["sh_count"] = sh_count + update_stmt = update_stmt.values(sh_count=sh_count) if not isinstance(s_count, _UnsetSentinel): - update_fields["s_count"] = s_count + update_stmt = update_stmt.values(s_count=s_count) if not isinstance(a_count, _UnsetSentinel): 
- update_fields["a_count"] = a_count + update_stmt = update_stmt.values(a_count=a_count) - query = f"""\ - UPDATE stats - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - AND mode = :mode - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } | update_fields - await app.state.services.database.execute(query, params) + compiled = update_stmt.compile(dialect=DIALECT) + await app.state.services.database.execute(str(compiled), compiled.params) - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = :id - AND mode = :mode - """ - params = { - "id": player_id, - "mode": mode, - } - stats = await app.state.services.database.fetch_one(query, params) - return cast(Stat, dict(stats._mapping)) if stats is not None else None + select_stmt = ( + select(READ_PARAMS) + .where(StatsTable.id == player_id) + .where(StatsTable.mode == mode) + ) + compiled = select_stmt.compile(dialect=DIALECT) + stat = await app.state.services.database.fetch_one( + query=str(compiled), + values=compiled.params, + ) + return cast(Stat, dict(stat._mapping)) if stat is not None else None # TODO: delete? diff --git a/poetry.lock b/poetry.lock index 51c6809af..ef594dcf9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -76,13 +76,13 @@ files = [ [[package]] name = "anyio" -version = "4.2.0" +version = "4.3.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, - {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, ] [package.dependencies] @@ -614,18 +614,18 @@ files = [ [[package]] name = "databases" -version = "0.6.2" +version = "0.8.0" description = "Async database support for Python." optional = false python-versions = ">=3.7" files = [ - {file = "databases-0.6.2-py3-none-any.whl", hash = "sha256:ff4010136ac2bb9da2322a2ffda4ef9185ae1c365e5891e52924dd9499d33dc4"}, - {file = "databases-0.6.2.tar.gz", hash = "sha256:b09c370ad7c2f64c7f4316c096e265dc2e28304732639889272390decda2f893"}, + {file = "databases-0.8.0-py3-none-any.whl", hash = "sha256:0ceb7fd5c740d846e1f4f58c0256d780a6786841ec8e624a21f1eb1b51a9093d"}, + {file = "databases-0.8.0.tar.gz", hash = "sha256:6544d82e9926f233d694ec29cd018403444c7fb6e863af881a8304d1ff5cfb90"}, ] [package.dependencies] aiomysql = {version = "*", optional = true, markers = "extra == \"mysql\""} -sqlalchemy = ">=1.4,<=1.4.41" +sqlalchemy = ">=1.4.42,<1.5" [package.extras] aiomysql = ["aiomysql"] @@ -899,13 +899,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.2" +version = "1.0.3" description = "A minimal low-level HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, - {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, + {file = "httpcore-1.0.3-py3-none-any.whl", hash = "sha256:9a6a501c3099307d9fd76ac244e08503427679b1e81ceb1d922485e2f2462ad2"}, + {file = "httpcore-1.0.3.tar.gz", hash = "sha256:5c0f9546ad17dac4d0772b0808856eb616eb8b48ce94f49ed819fd6982a8a544"}, ] [package.dependencies] @@ -916,7 +916,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.23.0)"] +trio = ["trio (>=0.22.0,<0.24.0)"] [[package]] name = "httpx" @@ -944,13 +944,13 @@ socks = ["socksio (==1.*)"] [[package]] name = "identify" -version = "2.5.34" +version = "2.5.35" description = "File identification library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "identify-2.5.34-py2.py3-none-any.whl", hash = "sha256:a4316013779e433d08b96e5eabb7f641e6c7942e4ab5d4c509ebd2e7a8994aed"}, - {file = "identify-2.5.34.tar.gz", hash = "sha256:ee17bc9d499899bc9eaec1ac7bf2dc9eedd480db9d88b96d123d3b64a9d34f5d"}, + {file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"}, + {file = "identify-2.5.35.tar.gz", hash = "sha256:10a7ca245cfcd756a554a7288159f72ff105ad233c7c4b9c6f0f4d108f5f6791"}, ] [package.extras] @@ -1609,60 +1609,65 @@ files = [ [[package]] name = "sqlalchemy" -version = "1.4.41" +version = "1.4.51" description = "Database Abstraction Library" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "SQLAlchemy-1.4.41-cp27-cp27m-macosx_10_14_x86_64.whl", hash = "sha256:13e397a9371ecd25573a7b90bd037db604331cf403f5318038c46ee44908c44d"}, - {file = "SQLAlchemy-1.4.41-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2d6495f84c4fd11584f34e62f9feec81bf373787b3942270487074e35cbe5330"}, - {file = "SQLAlchemy-1.4.41-cp27-cp27m-win32.whl", hash = "sha256:e570cfc40a29d6ad46c9aeaddbdcee687880940a3a327f2c668dd0e4ef0a441d"}, - {file = "SQLAlchemy-1.4.41-cp27-cp27m-win_amd64.whl", hash = "sha256:5facb7fd6fa8a7353bbe88b95695e555338fb038ad19ceb29c82d94f62775a05"}, - {file = "SQLAlchemy-1.4.41-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:f37fa70d95658763254941ddd30ecb23fc4ec0c5a788a7c21034fc2305dab7cc"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:361f6b5e3f659e3c56ea3518cf85fbdae1b9e788ade0219a67eeaaea8a4e4d2a"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0990932f7cca97fece8017414f57fdd80db506a045869d7ddf2dda1d7cf69ecc"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd767cf5d7252b1c88fcfb58426a32d7bd14a7e4942497e15b68ff5d822b41ad"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5102fb9ee2c258a2218281adcb3e1918b793c51d6c2b4666ce38c35101bb940e"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-win32.whl", hash = "sha256:2082a2d2fca363a3ce21cfa3d068c5a1ce4bf720cf6497fb3a9fc643a8ee4ddd"}, - {file = "SQLAlchemy-1.4.41-cp310-cp310-win_amd64.whl", hash = 
"sha256:e4b12e3d88a8fffd0b4ca559f6d4957ed91bd4c0613a4e13846ab8729dc5c251"}, - {file = "SQLAlchemy-1.4.41-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:90484a2b00baedad361402c257895b13faa3f01780f18f4a104a2f5c413e4536"}, - {file = "SQLAlchemy-1.4.41-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b67fc780cfe2b306180e56daaa411dd3186bf979d50a6a7c2a5b5036575cbdbb"}, - {file = "SQLAlchemy-1.4.41-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ad2b727fc41c7f8757098903f85fafb4bf587ca6605f82d9bf5604bd9c7cded"}, - {file = "SQLAlchemy-1.4.41-cp311-cp311-win32.whl", hash = "sha256:59bdc291165b6119fc6cdbc287c36f7f2859e6051dd923bdf47b4c55fd2f8bd0"}, - {file = "SQLAlchemy-1.4.41-cp311-cp311-win_amd64.whl", hash = "sha256:d2e054aed4645f9b755db85bc69fc4ed2c9020c19c8027976f66576b906a74f1"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:4ba7e122510bbc07258dc42be6ed45997efdf38129bde3e3f12649be70683546"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0dcf127bb99458a9d211e6e1f0f3edb96c874dd12f2503d4d8e4f1fd103790b"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e16c2be5cb19e2c08da7bd3a87fed2a0d4e90065ee553a940c4fc1a0fb1ab72b"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ebeeec5c14533221eb30bad716bc1fd32f509196318fb9caa7002c4a364e4c"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-win32.whl", hash = "sha256:3e2ef592ac3693c65210f8b53d0edcf9f4405925adcfc031ff495e8d18169682"}, - {file = "SQLAlchemy-1.4.41-cp36-cp36m-win_amd64.whl", hash = "sha256:eb30cf008850c0a26b72bd1b9be6730830165ce049d239cfdccd906f2685f892"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:c23d64a0b28fc78c96289ffbd0d9d1abd48d267269b27f2d34e430ea73ce4b26"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8eb8897367a21b578b26f5713833836f886817ee2ffba1177d446fa3f77e67c8"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:14576238a5f89bcf504c5f0a388d0ca78df61fb42cb2af0efe239dc965d4f5c9"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:639e1ae8d48b3c86ffe59c0daa9a02e2bfe17ca3d2b41611b30a0073937d4497"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-win32.whl", hash = "sha256:0005bd73026cd239fc1e8ccdf54db58b6193be9a02b3f0c5983808f84862c767"}, - {file = "SQLAlchemy-1.4.41-cp37-cp37m-win_amd64.whl", hash = "sha256:5323252be2bd261e0aa3f33cb3a64c45d76829989fa3ce90652838397d84197d"}, - {file = "SQLAlchemy-1.4.41-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:05f0de3a1dc3810a776275763764bb0015a02ae0f698a794646ebc5fb06fad33"}, - {file = "SQLAlchemy-1.4.41-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0002e829142b2af00b4eaa26c51728f3ea68235f232a2e72a9508a3116bd6ed0"}, - {file = "SQLAlchemy-1.4.41-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:22ff16cedab5b16a0db79f1bc99e46a6ddececb60c396562e50aab58ddb2871c"}, - {file = 
"SQLAlchemy-1.4.41-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccfd238f766a5bb5ee5545a62dd03f316ac67966a6a658efb63eeff8158a4bbf"}, - {file = "SQLAlchemy-1.4.41-cp38-cp38-win32.whl", hash = "sha256:58bb65b3274b0c8a02cea9f91d6f44d0da79abc993b33bdedbfec98c8440175a"}, - {file = "SQLAlchemy-1.4.41-cp38-cp38-win_amd64.whl", hash = "sha256:ce8feaa52c1640de9541eeaaa8b5fb632d9d66249c947bb0d89dd01f87c7c288"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:199a73c31ac8ea59937cc0bf3dfc04392e81afe2ec8a74f26f489d268867846c"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676d51c9f6f6226ae8f26dc83ec291c088fe7633269757d333978df78d931ab"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:036d8472356e1d5f096c5e0e1a7e0f9182140ada3602f8fff6b7329e9e7cfbcd"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2307495d9e0ea00d0c726be97a5b96615035854972cc538f6e7eaed23a35886c"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-win32.whl", hash = "sha256:9c56e19780cd1344fcd362fd6265a15f48aa8d365996a37fab1495cae8fcd97d"}, - {file = "SQLAlchemy-1.4.41-cp39-cp39-win_amd64.whl", hash = "sha256:f5fa526d027d804b1f85cdda1eb091f70bde6fb7d87892f6dd5a48925bc88898"}, - {file = "SQLAlchemy-1.4.41.tar.gz", hash = "sha256:0292f70d1797e3c54e862e6f30ae474014648bc9c723e14a2fda730adb0a9791"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:1a09d5bd1a40d76ad90e5570530e082ddc000e1d92de495746f6257dc08f166b"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2be4e6294c53f2ec8ea36486b56390e3bcaa052bf3a9a47005687ccf376745d1"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca484ca11c65e05639ffe80f20d45e6be81fbec7683d6c9a15cd421e6e8b340"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0535d5b57d014d06ceeaeffd816bb3a6e2dddeb670222570b8c4953e2d2ea678"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af55cc207865d641a57f7044e98b08b09220da3d1b13a46f26487cc2f898a072"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-win32.whl", hash = "sha256:7af40425ac535cbda129d9915edcaa002afe35d84609fd3b9d6a8c46732e02ee"}, + {file = "SQLAlchemy-1.4.51-cp310-cp310-win_amd64.whl", hash = "sha256:8d1d7d63e5d2f4e92a39ae1e897a5d551720179bb8d1254883e7113d3826d43c"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:eaeeb2464019765bc4340214fca1143081d49972864773f3f1e95dba5c7edc7d"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7deeae5071930abb3669b5185abb6c33ddfd2398f87660fafdb9e6a5fb0f3f2f"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0892e7ac8bc76da499ad3ee8de8da4d7905a3110b952e2a35a940dab1ffa550e"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-win32.whl", hash = 
"sha256:50e074aea505f4427151c286955ea025f51752fa42f9939749336672e0674c81"}, + {file = "SQLAlchemy-1.4.51-cp311-cp311-win_amd64.whl", hash = "sha256:3b0cd89a7bd03f57ae58263d0f828a072d1b440c8c2949f38f3b446148321171"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a33cb3f095e7d776ec76e79d92d83117438b6153510770fcd57b9c96f9ef623d"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cacc0b2dd7d22a918a9642fc89840a5d3cee18a0e1fe41080b1141b23b10916"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245c67c88e63f1523e9216cad6ba3107dea2d3ee19adc359597a628afcabfbcb"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-win32.whl", hash = "sha256:8e702e7489f39375601c7ea5a0bef207256828a2bc5986c65cb15cd0cf097a87"}, + {file = "SQLAlchemy-1.4.51-cp312-cp312-win_amd64.whl", hash = "sha256:0525c4905b4b52d8ccc3c203c9d7ab2a80329ffa077d4bacf31aefda7604dc65"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:1980e6eb6c9be49ea8f89889989127daafc43f0b1b6843d71efab1514973cca0"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ec7a0ed9b32afdf337172678a4a0e6419775ba4e649b66f49415615fa47efbd"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:352df882088a55293f621328ec33b6ffca936ad7f23013b22520542e1ab6ad1b"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:86a22143a4001f53bf58027b044da1fb10d67b62a785fc1390b5c7f089d9838c"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c37bc677690fd33932182b85d37433845de612962ed080c3e4d92f758d1bd894"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-win32.whl", hash = "sha256:d0a83afab5e062abffcdcbcc74f9d3ba37b2385294dd0927ad65fc6ebe04e054"}, + {file = "SQLAlchemy-1.4.51-cp36-cp36m-win_amd64.whl", hash = "sha256:a61184c7289146c8cff06b6b41807c6994c6d437278e72cf00ff7fe1c7a263d1"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:3f0ef620ecbab46e81035cf3dedfb412a7da35340500ba470f9ce43a1e6c423b"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c55040d8ea65414de7c47f1a23823cd9f3fad0dc93e6b6b728fee81230f817b"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ef80328e3fee2be0a1abe3fe9445d3a2e52a1282ba342d0dab6edf1fef4707"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f8cafa6f885a0ff5e39efa9325195217bb47d5929ab0051636610d24aef45ade"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8f2df79a46e130235bc5e1bbef4de0583fb19d481eaa0bffa76e8347ea45ec6"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-win32.whl", hash = "sha256:f2e5b6f5cf7c18df66d082604a1d9c7a2d18f7d1dbe9514a2afaccbb51cc4fc3"}, + {file = "SQLAlchemy-1.4.51-cp37-cp37m-win_amd64.whl", hash = 
"sha256:5e180fff133d21a800c4f050733d59340f40d42364fcb9d14f6a67764bdc48d2"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7d8139ca0b9f93890ab899da678816518af74312bb8cd71fb721436a93a93298"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb18549b770351b54e1ab5da37d22bc530b8bfe2ee31e22b9ebe650640d2ef12"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55e699466106d09f028ab78d3c2e1f621b5ef2c8694598242259e4515715da7c"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2ad16880ccd971ac8e570550fbdef1385e094b022d6fc85ef3ce7df400dddad3"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b97fd5bb6b7c1a64b7ac0632f7ce389b8ab362e7bd5f60654c2a418496be5d7f"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-win32.whl", hash = "sha256:cecb66492440ae8592797dd705a0cbaa6abe0555f4fa6c5f40b078bd2740fc6b"}, + {file = "SQLAlchemy-1.4.51-cp38-cp38-win_amd64.whl", hash = "sha256:39b02b645632c5fe46b8dd30755682f629ffbb62ff317ecc14c998c21b2896ff"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b03850c290c765b87102959ea53299dc9addf76ca08a06ea98383348ae205c99"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e646b19f47d655261b22df9976e572f588185279970efba3d45c377127d35349"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3cf56cc36d42908495760b223ca9c2c0f9f0002b4eddc994b24db5fcb86a9e4"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0d661cff58c91726c601cc0ee626bf167b20cc4d7941c93c5f3ac28dc34ddbea"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3823dda635988e6744d4417e13f2e2b5fe76c4bf29dd67e95f98717e1b094cad"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-win32.whl", hash = "sha256:b00cf0471888823b7a9f722c6c41eb6985cf34f077edcf62695ac4bed6ec01ee"}, + {file = "SQLAlchemy-1.4.51-cp39-cp39-win_amd64.whl", hash = "sha256:a055ba17f4675aadcda3005df2e28a86feb731fdcc865e1f6b4f209ed1225cba"}, + {file = "SQLAlchemy-1.4.51.tar.gz", hash = "sha256:e7908c2025eb18394e32d65dd02d2e37e17d733cdbe7d78231c2b6d7eb20cdb9"}, ] [package.dependencies] greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} [package.extras] -aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] -aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4)", "greenlet (!=0.4.17)"] mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2)"] @@ -1672,14 +1677,28 @@ mssql-pyodbc = ["pyodbc"] mypy 
= ["mypy (>=0.910)", "sqlalchemy2-stubs"] mysql = ["mysqlclient (>=1.4.0)", "mysqlclient (>=1.4.0,<2)"] mysql-connector = ["mysql-connector-python"] -oracle = ["cx-oracle (>=7)", "cx-oracle (>=7,<8)"] +oracle = ["cx_oracle (>=7)", "cx_oracle (>=7,<8)"] postgresql = ["psycopg2 (>=2.7)"] postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] postgresql-pg8000 = ["pg8000 (>=1.16.6,!=1.29.0)"] postgresql-psycopg2binary = ["psycopg2-binary"] postgresql-psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql", "pymysql (<1)"] -sqlcipher = ["sqlcipher3-binary"] +sqlcipher = ["sqlcipher3_binary"] + +[[package]] +name = "sqlalchemy2-stubs" +version = "0.0.2a38" +description = "Typing Stubs for SQLAlchemy 1.4" +optional = false +python-versions = ">=3.6" +files = [ + {file = "sqlalchemy2-stubs-0.0.2a38.tar.gz", hash = "sha256:861d722abeb12f13eacd775a9f09379b11a5a9076f469ccd4099961b95800f9e"}, + {file = "sqlalchemy2_stubs-0.0.2a38-py3-none-any.whl", hash = "sha256:b62aa46943807287550e2033dafe07564b33b6a815fbaa3c144e396f9cc53bcb"}, +] + +[package.dependencies] +typing-extensions = ">=3.7.4" [[package]] name = "starlette" @@ -1793,13 +1812,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.0" +version = "2.2.1" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.0-py3-none-any.whl", hash = "sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224"}, - {file = "urllib3-2.2.0.tar.gz", hash = "sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20"}, + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, ] [package.extras] @@ -1910,4 +1929,4 @@ cython = "*" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "7b491f78b9acb025f1df3a739965f32a40d1d108ef57bb35be2b57a3380f95f5" +content-hash = "b5d6c82b060f449f4b59e5b9e38e78871bbdf72b615f0a06cd6d13722e2871fe" diff --git a/pyproject.toml b/pyproject.toml index f130677bb..7dc5ca374 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,8 +60,7 @@ py3rijndael = "0.3.3" pytimeparse = "1.1.8" pydantic = "2.6.1" redis = { extras = ["hiredis"], version = "5.0.1" } -sqlalchemy = "1.4.41" -databases = { extras = ["mysql"], version = "0.6.2" } +sqlalchemy = ">=1.4.42,<1.5" akatsuki-pp-py = "1.0.0" cryptography = "42.0.2" tenacity = "8.2.3" @@ -73,6 +72,7 @@ asgi-lifespan = "2.1.0" respx = "0.20.2" tzdata = "2024.1" coverage = "^7.4.1" +databases = {version = "^0.8.0", extras = ["mysql"]} [tool.poetry.group.dev.dependencies] pre-commit = "3.6.1" @@ -84,6 +84,7 @@ types-pymysql = "1.1.0.1" types-requests = "2.31.0.20240125" mypy = "1.8.0" types-pyyaml = "^6.0.12.12" +sqlalchemy2-stubs = "^0.0.2a38" [build-system] requires = ["poetry-core"] From 4aeadf87307c2129cbc6183eb03a0489e7a9c747 Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Tue, 20 Feb 2024 06:44:33 -0500 Subject: [PATCH 32/48] Remove use of passed-around/reused db conns (#626) * Remove use of passed-around/reused db conns * update migration failure log --------- Co-authored-by: James Wilson --- app/api/domains/cho.py | 17 +++---- app/api/init_api.py | 3 +- app/commands.py | 2 +- app/objects/beatmap.py | 101 ++++++++++++++++++------------------- app/objects/collections.py | 8 +-- app/objects/player.py | 9 ++-- app/state/services.py | 52 +++++++++---------- 7 
files changed, 92 insertions(+), 100 deletions(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index a0f5a693f..643aa66a7 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -186,13 +186,11 @@ async def bancho_handler( if osu_token is None: # the client is performing a login - async with app.state.services.database.connection() as db_conn: - login_data = await handle_osu_login_request( - request.headers, - await request.body(), - ip, - db_conn, - ) + login_data = await handle_osu_login_request( + request.headers, + await request.body(), + ip, + ) return Response( content=login_data["response_body"], @@ -625,7 +623,6 @@ async def handle_osu_login_request( headers: Mapping[str, str], body: bytes, ip: IPAddress, - db_conn: databases.core.Connection, ) -> LoginResponse: """\ Login has no specific packet, but happens when the osu! @@ -895,8 +892,8 @@ async def handle_osu_login_request( # fetch some of the player's # information from sql to be cached. - await player.stats_from_sql_full(db_conn) - await player.relationships_from_sql(db_conn) + await player.stats_from_sql_full() + await player.relationships_from_sql() # TODO: fetch player.recent_scores from sql diff --git a/app/api/init_api.py b/app/api/init_api.py index 91b7cd755..c2ce4f398 100644 --- a/app/api/init_api.py +++ b/app/api/init_api.py @@ -97,8 +97,7 @@ async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[Never]: await app.state.services.run_sql_migrations() - async with app.state.services.database.connection() as db_conn: - await collections.initialize_ram_caches(db_conn) + await collections.initialize_ram_caches() await app.bg_loops.initialize_housekeeping_tasks() diff --git a/app/commands.py b/app/commands.py index dfcb22acf..65be3ad15 100644 --- a/app/commands.py +++ b/app/commands.py @@ -650,7 +650,7 @@ async def _map(ctx: Context) -> str | None: # for updating cache would be faster? # surely this will not scale as well... 
- async with app.state.services.database.connection() as db_conn: + async with app.state.services.database.transaction(): if ctx.args[1] == "set": # update all maps in the set for _bmap in bmap.set.maps: diff --git a/app/objects/beatmap.py b/app/objects/beatmap.py index 9e72cf822..c4d51d013 100644 --- a/app/objects/beatmap.py +++ b/app/objects/beatmap.py @@ -834,62 +834,61 @@ async def _from_bsid_cache(bsid: int) -> BeatmapSet | None: @classmethod async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None: """Fetch a mapset from the database by set id.""" - async with app.state.services.database.connection() as db_conn: - last_osuapi_check = await db_conn.fetch_val( - "SELECT last_osuapi_check FROM mapsets WHERE id = :set_id", - {"set_id": bsid}, - column=0, # last_osuapi_check - ) + last_osuapi_check = await app.state.services.database.fetch_val( + "SELECT last_osuapi_check FROM mapsets WHERE id = :set_id", + {"set_id": bsid}, + column=0, # last_osuapi_check + ) - if last_osuapi_check is None: - return None - - bmap_set = cls(id=bsid, last_osuapi_check=last_osuapi_check) - - for row in await maps_repo.fetch_many(set_id=bsid): - bmap = Beatmap( - md5=row["md5"], - id=row["id"], - set_id=row["set_id"], - artist=row["artist"], - title=row["title"], - version=row["version"], - creator=row["creator"], - last_update=row["last_update"], - total_length=row["total_length"], - max_combo=row["max_combo"], - status=RankedStatus(row["status"]), - frozen=row["frozen"], - plays=row["plays"], - passes=row["passes"], - mode=GameMode(row["mode"]), - bpm=row["bpm"], - cs=row["cs"], - od=row["od"], - ar=row["ar"], - hp=row["hp"], - diff=row["diff"], - filename=row["filename"], - map_set=bmap_set, - ) + if last_osuapi_check is None: + return None + + bmap_set = cls(id=bsid, last_osuapi_check=last_osuapi_check) + + for row in await maps_repo.fetch_many(set_id=bsid): + bmap = Beatmap( + md5=row["md5"], + id=row["id"], + set_id=row["set_id"], + artist=row["artist"], + title=row["title"], + version=row["version"], + creator=row["creator"], + last_update=row["last_update"], + total_length=row["total_length"], + max_combo=row["max_combo"], + status=RankedStatus(row["status"]), + frozen=row["frozen"], + plays=row["plays"], + passes=row["passes"], + mode=GameMode(row["mode"]), + bpm=row["bpm"], + cs=row["cs"], + od=row["od"], + ar=row["ar"], + hp=row["hp"], + diff=row["diff"], + filename=row["filename"], + map_set=bmap_set, + ) - # XXX: tempfix for bancho.py None: if app.settings.DEBUG: log(f"{channel} removed from channels list.") - async def prepare(self, db_conn: databases.core.Connection) -> None: + async def prepare(self) -> None: """Fetch data from sql & return; preparing to run the server.""" log("Fetching channels from sql.", Ansi.LCYAN) for row in await channels_repo.fetch_many(): @@ -284,10 +284,10 @@ def remove(self, player: Player) -> None: super().remove(player) -async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: +async def initialize_ram_caches() -> None: """Setup & cache the global collections before listening for connections.""" # fetch channels, clans and pools from db - await app.state.sessions.channels.prepare(db_conn) + await app.state.sessions.channels.prepare() bot = await users_repo.fetch_one(id=1) if bot is None: @@ -308,7 +308,7 @@ async def initialize_ram_caches(db_conn: databases.core.Connection) -> None: # static api keys app.state.sessions.api_keys = { row["api_key"]: row["id"] - for row in await db_conn.fetch_all( + for row in await 
app.state.services.database.fetch_all( "SELECT id, api_key FROM users WHERE api_key IS NOT NULL", ) } diff --git a/app/objects/player.py b/app/objects/player.py index d99ba1909..d13ad3e69 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -497,8 +497,7 @@ async def unrestrict(self, admin: Player, reason: str) -> None: ) if not self.is_online: - async with app.state.services.database.connection() as db_conn: - await self.stats_from_sql_full(db_conn) + await self.stats_from_sql_full() for mode, stats in self.stats.items(): await app.state.services.redis.zadd( @@ -894,9 +893,9 @@ async def remove_block(self, player: Player) -> None: log(f"{self} unblocked {player}.") - async def relationships_from_sql(self, db_conn: databases.core.Connection) -> None: + async def relationships_from_sql(self) -> None: """Retrieve `self`'s relationships from sql.""" - for row in await db_conn.fetch_all( + for row in await app.state.services.database.fetch_all( "SELECT user2, type FROM relationships WHERE user1 = :user1", {"user1": self.id}, ): @@ -949,7 +948,7 @@ async def update_rank(self, mode: GameMode) -> int: return await self.get_global_rank(mode) - async def stats_from_sql_full(self, db_conn: databases.core.Connection) -> None: + async def stats_from_sql_full(self) -> None: """Retrieve `self`'s stats (all modes) from sql.""" for row in await stats_repo.fetch_many(player_id=self.id): game_mode = GameMode(row["mode"]) diff --git a/app/state/services.py b/app/state/services.py index 06728b324..d206a9aa7 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -464,31 +464,29 @@ async def run_sql_migrations() -> None: Ansi.LMAGENTA, ) - # XXX: so it turns out we can't use a transaction here (at least with mysql) - # to roll back changes, as any structural changes to tables implicitly - # commit: https://dev.mysql.com/doc/refman/5.7/en/implicit-commit.html - async with app.state.services.database.connection() as db_conn: - for query in queries: - try: - await db_conn.execute(query) - except pymysql.err.MySQLError as exc: - log(f"Failed: {query}", Ansi.GRAY) - log(repr(exc)) - log( - "SQL failed to update - unless you've been " - "modifying sql and know what caused this, " - "please please contact cmyui#0425.", - Ansi.LRED, - ) - raise KeyboardInterrupt from exc - else: - # all queries executed successfully - await db_conn.execute( - "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) " - "VALUES (:major, :minor, :micro, NOW())", - { - "major": software_version.major, - "minor": software_version.minor, - "micro": software_version.micro, - }, + # XXX: we can't use a transaction here with mysql as structural changes to + # tables implicitly commit: https://dev.mysql.com/doc/refman/5.7/en/implicit-commit.html + for query in queries: + try: + await app.state.services.database.execute(query) + except pymysql.err.MySQLError as exc: + log(f"Failed: {query}", Ansi.GRAY) + log(repr(exc)) + log( + "SQL failed to update - unless you've been " + "modifying sql and know what caused this, " + "please contact @cmyui on Discord.", + Ansi.LRED, ) + raise KeyboardInterrupt from exc + else: + # all queries executed successfully + await app.state.services.database.execute( + "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) " + "VALUES (:major, :minor, :micro, NOW())", + { + "major": software_version.major, + "minor": software_version.minor, + "micro": software_version.micro, + }, + ) From 039ddcafa7cb84bdf773a7cfb967e1740de8fc97 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:50:35 +0000 Subject: [PATCH 33/48] [pre-commit.ci] pre-commit autoupdate (#633) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/asottile/pyupgrade: v3.15.0 → v3.15.1](https://github.com/asottile/pyupgrade/compare/v3.15.0...v3.15.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 592b530e5..60b4aa42f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: black - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.15.1 hooks: - id: pyupgrade args: [--py311-plus, --keep-runtime-typing] From 0dafc7e54cebda8343a8fe761414cd5fc6c2ff02 Mon Sep 17 00:00:00 2001 From: James Wilson Date: Tue, 20 Feb 2024 15:38:00 +0000 Subject: [PATCH 34/48] ensure `/get_mappool` doesn't crash if pool creator is not in a clan (#634) --- app/api/v1/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/app/api/v1/api.py b/app/api/v1/api.py index 0109a80d5..a51a3f0d1 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -971,7 +971,11 @@ async def api_get_pool( status_code=status.HTTP_404_NOT_FOUND, ) - pool_creator_clan = await clans_repo.fetch_one(id=pool_creator.clan_id) + pool_creator_clan = ( + await clans_repo.fetch_one(id=pool_creator.clan_id) + if pool_creator.clan_id is not None + else None + ) pool_creator_clan_members: list[users_repo.User] = [] if pool_creator_clan is not None: pool_creator_clan_members = await users_repo.fetch_many( From eb62357535be6c33d4aec1bafc9a3339456497c4 Mon Sep 17 00:00:00 2001 From: cmyui Date: Thu, 22 Feb 2024 23:20:33 -0500 Subject: [PATCH 35/48] fix typing on lifespan --- app/api/init_api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/app/api/init_api.py b/app/api/init_api.py index c2ce4f398..1602c558d 100644 --- a/app/api/init_api.py +++ b/app/api/init_api.py @@ -8,7 +8,6 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any -from typing import Never import starlette.routing from fastapi import FastAPI @@ -69,7 +68,7 @@ def openapi(self) -> dict[str, Any]: @asynccontextmanager -async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[Never]: +async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[None]: if isinstance(sys.stdout, io.TextIOWrapper): sys.stdout.reconfigure(encoding="utf-8") @@ -107,7 +106,7 @@ async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[Never]: Ansi.LMAGENTA, ) - yield # type: ignore + yield # we want to attempt to gracefully finish any ongoing connections # and shut down any of the housekeeping tasks running in the background. 
From 31a38fe8d61eaf8d452f043020aa43328852c636 Mon Sep 17 00:00:00 2001 From: kingdom5500 <37349466+kingdom5500@users.noreply.github.com> Date: Sun, 25 Feb 2024 20:06:15 +0000 Subject: [PATCH 36/48] fix: use all scores (not top 100) in recalc.py (#635) --- tools/recalc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tools/recalc.py b/tools/recalc.py index df12b6baf..172a5053e 100644 --- a/tools/recalc.py +++ b/tools/recalc.py @@ -134,15 +134,13 @@ async def recalculate_user( if not total_scores: return - top_100_pp = best_scores[:100] - # calculate new total weighted accuracy - weighted_acc = sum(row["acc"] * 0.95**i for i, row in enumerate(top_100_pp)) + weighted_acc = sum(row["acc"] * 0.95**i for i, row in enumerate(best_scores)) bonus_acc = 100.0 / (20 * (1 - 0.95**total_scores)) acc = (weighted_acc * bonus_acc) / 100 # calculate new total weighted pp - weighted_pp = sum(row["pp"] * 0.95**i for i, row in enumerate(top_100_pp)) + weighted_pp = sum(row["pp"] * 0.95**i for i, row in enumerate(best_scores)) bonus_pp = 416.6667 * (1 - 0.9994**total_scores) pp = round(weighted_pp + bonus_pp) From 0cc409ff25f609f6c3180fb6bd7a494a54a934a3 Mon Sep 17 00:00:00 2001 From: cmyui Date: Sun, 25 Feb 2024 21:54:05 -0500 Subject: [PATCH 37/48] more clear !pool list output --- app/commands.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/app/commands.py b/app/commands.py index 65be3ad15..7efa8c454 100644 --- a/app/commands.py +++ b/app/commands.py @@ -2222,9 +2222,14 @@ async def pool_list(ctx: Context) -> str | None: l = [f"Mappools ({len(tourney_pools)})"] for pool in tourney_pools: + created_by = await users_repo.fetch_one(id=pool["created_by"]) + if created_by is None: + log(f"Could not find pool creator (Id {pool['created_by']}).", Ansi.LRED) + continue + l.append( - f"[{pool['created_at']:%Y-%m-%d}] {pool['id']} " - f"{pool['name']}, by {pool['created_by']}.", + f"[{pool['created_at']:%Y-%m-%d}] " + f"{pool['name']}, by {created_by['name']}.", ) return "\n".join(l) From 26b8595d4ed9bcea76d4bfa629e765a06e711e8d Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Sun, 25 Feb 2024 22:51:52 -0500 Subject: [PATCH 38/48] [v5.1.0] Refactor the rest of the repositories to use sqlalchemy core 1.4 (#632) * Refactor the rest of the repositories to use sqlalchemy core 1.4 Co-authored-by: James Wilson * fix types & nullables * fmt * add & migrate to thin database adapter, ensure all queries are compatible with sqlalchemy 2.0 * bugfix startup process * Add `render_postcompile` flag to support `IN` clauses * log db queries in debug mode * remove default from logs.time (doesn't exist) * Fix bugs in favorites table repo * Bump minor version -- to 5.1.0 --------- Co-authored-by: James Wilson Co-authored-by: tsunyoku --- app/adapters/__init__.py | 0 app/adapters/database.py | 100 ++++++ app/api/domains/cho.py | 2 +- app/api/domains/osu.py | 27 +- app/api/v1/api.py | 2 +- app/commands.py | 44 ++- app/objects/beatmap.py | 3 +- app/objects/player.py | 12 +- app/repositories/__init__.py | 8 - app/repositories/achievements.py | 80 ++--- app/repositories/channels.py | 71 ++-- app/repositories/clans.py | 213 +++++------- app/repositories/client_hashes.py | 170 +++++----- app/repositories/comments.py | 132 ++++---- app/repositories/favourites.py | 104 +++--- app/repositories/ingame_logins.py | 149 ++++----- app/repositories/logs.py | 79 ++--- app/repositories/mail.py | 152 +++++---- app/repositories/map_requests.py | 130 ++++---- app/repositories/maps.py | 449 
++++++++++++-------------- app/repositories/ratings.py | 119 +++---- app/repositories/scores.py | 67 ++-- app/repositories/stats.py | 77 ++--- app/repositories/tourney_pool_maps.py | 209 +++++------- app/repositories/tourney_pools.py | 156 ++++----- app/repositories/user_achievements.py | 108 +++---- app/repositories/users.py | 322 +++++++++--------- app/state/services.py | 10 +- app/usecases/user_achievements.py | 12 +- pyproject.toml | 2 +- 30 files changed, 1353 insertions(+), 1656 deletions(-) create mode 100644 app/adapters/__init__.py create mode 100644 app/adapters/database.py diff --git a/app/adapters/__init__.py b/app/adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/adapters/database.py b/app/adapters/database.py new file mode 100644 index 000000000..117469a90 --- /dev/null +++ b/app/adapters/database.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import Any +from typing import cast + +from databases import Database as _Database +from databases.core import Transaction +from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb +from sqlalchemy.sql.compiler import Compiled +from sqlalchemy.sql.expression import ClauseElement + +from app import settings + + +class MySQLDialect(MySQLDialect_mysqldb): + default_paramstyle = "named" + + +DIALECT = MySQLDialect() + +MySQLRow = dict[str, Any] +MySQLParams = dict[str, Any] | None +MySQLQuery = ClauseElement | str + + +class Database: + def __init__(self, url: str) -> None: + self._database = _Database(url) + + async def connect(self) -> None: + await self._database.connect() + + async def disconnect(self) -> None: + await self._database.disconnect() + + def _compile(self, clause_element: ClauseElement) -> tuple[str, MySQLParams]: + compiled: Compiled = clause_element.compile( + dialect=DIALECT, + compile_kwargs={"render_postcompile": True}, + ) + if settings.DEBUG: + print(str(compiled), compiled.params) + return str(compiled), compiled.params + + async def fetch_one( + self, + query: MySQLQuery, + params: MySQLParams = None, + ) -> MySQLRow | None: + if isinstance(query, ClauseElement): + query, params = self._compile(query) + + row = await self._database.fetch_one(query, params) + return dict(row._mapping) if row is not None else None + + async def fetch_all( + self, + query: MySQLQuery, + params: MySQLParams = None, + ) -> list[MySQLRow]: + if isinstance(query, ClauseElement): + query, params = self._compile(query) + + rows = await self._database.fetch_all(query, params) + return [dict(row._mapping) for row in rows] + + async def fetch_val( + self, + query: MySQLQuery, + params: MySQLParams = None, + column: Any = 0, + ) -> Any: + if isinstance(query, ClauseElement): + query, params = self._compile(query) + + val = await self._database.fetch_val(query, params, column) + return val + + async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int: + if isinstance(query, ClauseElement): + query, params = self._compile(query) + + rec_id = await self._database.execute(query, params) + return cast(int, rec_id) + + # NOTE: this accepts str since current execute_many uses are not using alchemy. + # alchemy does execute_many in a single query so this method will be unneeded once raw SQL is not in use. 
+ async def execute_many(self, query: str, params: list[MySQLParams]) -> None: + if isinstance(query, ClauseElement): + query, _ = self._compile(query) + + await self._database.execute_many(query, params) + + def transaction( + self, + *, + force_rollback: bool = False, + **kwargs: Any, + ) -> Transaction: + return self._database.transaction(force_rollback=force_rollback, **kwargs) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 643aa66a7..14ca8b36f 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -813,7 +813,7 @@ async def handle_osu_login_request( # country wasn't stored on registration. log(f"Fixing {login_data['username']}'s country.", Ansi.LGREEN) - await users_repo.update( + await users_repo.partial_update( id=user_info["id"], country=geoloc["country"]["acronym"], ) diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index 605b6468d..0122c5ec0 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -468,19 +468,17 @@ async def osuSearchSetHandler( return Response(b"") # invalid args # Get all set data. - rec = await app.state.services.database.fetch_one( + bmapset = await app.state.services.database.fetch_one( "SELECT DISTINCT set_id, artist, " "title, status, creator, last_update " f"FROM maps WHERE {k} = :v", {"v": v}, ) - - if rec is None: + if bmapset is None: # TODO: get from osu! return Response(b"") rating = 10.0 # TODO: real data - bmapset = dict(rec._mapping) return Response( ( @@ -979,7 +977,7 @@ async def osuSubmitModularSelector( server_achievements = await achievements_usecases.fetch_many() player_achievements = await user_achievements_usecases.fetch_many( - score.player.id, + user_id=score.player.id, ) for server_achievement in server_achievements: @@ -1184,17 +1182,14 @@ async def get_leaderboard_scores( # TODO: customizability of the number of scores query.append("ORDER BY _score DESC LIMIT 50") - score_rows = [ - dict(r._mapping) - for r in await app.state.services.database.fetch_all( - " ".join(query), - params, - ) - ] + score_rows = await app.state.services.database.fetch_all( + " ".join(query), + params, + ) if score_rows: # None or [] # fetch player's personal best score - personal_best_score_rec = await app.state.services.database.fetch_one( + personal_best_score_row = await app.state.services.database.fetch_one( f"SELECT id, {scoring_metric} AS _score, " "max_combo, n50, n100, n300, " "nmiss, nkatu, ngeki, perfect, mods, " @@ -1206,9 +1201,7 @@ async def get_leaderboard_scores( {"map_md5": map_md5, "mode": mode, "user_id": player.id}, ) - if personal_best_score_rec is not None: - personal_best_score_row = dict(personal_best_score_rec._mapping) - + if personal_best_score_row is not None: # calculate the rank of the score. 
p_best_rank = 1 + await app.state.services.database.fetch_val( "SELECT COUNT(*) FROM scores s " @@ -1226,8 +1219,6 @@ async def get_leaderboard_scores( # attach rank to personal best row personal_best_score_row["rank"] = p_best_rank - else: - personal_best_score_row = None else: score_rows = [] personal_best_score_row = None diff --git a/app/api/v1/api.py b/app/api/v1/api.py index a51a3f0d1..b09918114 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -804,7 +804,7 @@ async def api_get_replay( 'attachment; filename="{username} - ' "{artist} - {title} [{version}] " '({play_time:%Y-%m-%d}).osr"' - ).format(**dict(row._mapping)), + ).format(**row), }, ) diff --git a/app/commands.py b/app/commands.py index 7efa8c454..5203d8d51 100644 --- a/app/commands.py +++ b/app/commands.py @@ -278,7 +278,7 @@ async def changename(ctx: Context) -> str | None: return "Username already taken by another player." # all checks passed, update their name - await users_repo.update(ctx.player.id, name=name) + await users_repo.partial_update(ctx.player.id, name=name) ctx.player.enqueue( app.packets.notification(f"Your username has been changed to {name}!"), @@ -388,21 +388,17 @@ async def top(ctx: Context) -> str | None: # !top rx!std mode = GAMEMODE_REPR_LIST.index(ctx.args[0]) - scores = [ - dict(s._mapping) - for s in await app.state.services.database.fetch_all( - "SELECT s.pp, b.artist, b.title, b.version, b.set_id map_set_id, b.id map_id " - "FROM scores s " - "LEFT JOIN maps b ON b.md5 = s.map_md5 " - "WHERE s.userid = :user_id " - "AND s.mode = :mode " - "AND s.status = 2 " - "AND b.status in (2, 3) " - "ORDER BY s.pp DESC LIMIT 10", - {"user_id": player.id, "mode": mode}, - ) - ] - + scores = await app.state.services.database.fetch_all( + "SELECT s.pp, b.artist, b.title, b.version, b.set_id map_set_id, b.id map_id " + "FROM scores s " + "LEFT JOIN maps b ON b.md5 = s.map_md5 " + "WHERE s.userid = :user_id " + "AND s.mode = :mode " + "AND s.status = 2 " + "AND b.status in (2, 3) " + "ORDER BY s.pp DESC LIMIT 10", + {"user_id": player.id, "mode": mode}, + ) if not scores: return "No scores" @@ -564,7 +560,7 @@ async def apikey(ctx: Context) -> str | None: # generate new token ctx.player.api_key = str(uuid.uuid4()) - await users_repo.update(ctx.player.id, api_key=ctx.player.api_key) + await users_repo.partial_update(ctx.player.id, api_key=ctx.player.api_key) app.state.sessions.api_keys[ctx.player.api_key] = ctx.player.id return f"API key generated. Copy your api key from (this url)[http://{ctx.player.api_key}]." 
@@ -654,7 +650,7 @@ async def _map(ctx: Context) -> str | None: if ctx.args[1] == "set": # update all maps in the set for _bmap in bmap.set.maps: - await maps_repo.update(_bmap.id, status=new_status, frozen=True) + await maps_repo.partial_update(_bmap.id, status=new_status, frozen=True) # make sure cache and db are synced about the newest change for _bmap in app.state.cache.beatmapset[bmap.set_id].maps: @@ -671,7 +667,7 @@ async def _map(ctx: Context) -> str | None: else: # update only map - await maps_repo.update(bmap.id, status=new_status, frozen=True) + await maps_repo.partial_update(bmap.id, status=new_status, frozen=True) # make sure cache and db are synced about the newest change if bmap.md5 in app.state.cache.beatmap: @@ -2326,7 +2322,7 @@ async def clan_create(ctx: Context) -> str | None: ctx.player.clan_id = new_clan["id"] ctx.player.clan_priv = ClanPrivileges.Owner - await users_repo.update( + await users_repo.partial_update( ctx.player.id, clan_id=new_clan["id"], clan_priv=ClanPrivileges.Owner, @@ -2362,7 +2358,7 @@ async def clan_disband(ctx: Context) -> str | None: if not clan: return "You're not a member of a clan!" - await clans_repo.delete(clan["id"]) + await clans_repo.delete_one(clan["id"]) # remove all members from the clan clan_member_ids = [ @@ -2370,7 +2366,7 @@ async def clan_disband(ctx: Context) -> str | None: for clan_member in await users_repo.fetch_many(clan_id=clan["id"]) ] for member_id in clan_member_ids: - await users_repo.update(member_id, clan_id=0, clan_priv=0) + await users_repo.partial_update(member_id, clan_id=0, clan_priv=0) member = app.state.sessions.players.get(id=member_id) if member: @@ -2423,7 +2419,7 @@ async def clan_leave(ctx: Context) -> str | None: clan_members = await users_repo.fetch_many(clan_id=clan["id"]) - await users_repo.update(ctx.player.id, clan_id=0, clan_priv=0) + await users_repo.partial_update(ctx.player.id, clan_id=0, clan_priv=0) ctx.player.clan_id = None ctx.player.clan_priv = None @@ -2431,7 +2427,7 @@ async def clan_leave(ctx: Context) -> str | None: if not clan_members: # no members left, disband clan - await clans_repo.delete(clan["id"]) + await clans_repo.delete_one(clan["id"]) # announce clan disbanding announce_chan = app.state.sessions.channels.get_by_name("#announce") diff --git a/app/objects/beatmap.py b/app/objects/beatmap.py index c4d51d013..6039f583b 100644 --- a/app/objects/beatmap.py +++ b/app/objects/beatmap.py @@ -885,8 +885,7 @@ async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None: ) .translate(IGNORED_BEATMAP_CHARS) ) - - await maps_repo.update(bmap.id, filename=bmap.filename) + await maps_repo.partial_update(bmap.id, filename=bmap.filename) bmap_set.maps.append(bmap) diff --git a/app/objects/player.py b/app/objects/player.py index d13ad3e69..85e3193c3 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -412,7 +412,7 @@ async def update_privs(self, new: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await users_repo.update( + await users_repo.partial_update( id=self.id, priv=self.priv, ) @@ -424,7 +424,7 @@ async def add_privs(self, bits: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await users_repo.update( + await users_repo.partial_update( id=self.id, priv=self.priv, ) @@ -441,7 +441,7 @@ async def remove_privs(self, bits: Privileges) -> None: if "bancho_priv" in vars(self): del self.bancho_priv # wipe cached_property - await users_repo.update( + await 
users_repo.partial_update( id=self.id, priv=self.priv, ) @@ -527,7 +527,7 @@ async def silence(self, admin: Player, duration: float, reason: str) -> None: """Silence `self` for `duration` seconds, and log to sql.""" self.silence_end = int(time.time() + duration) - await users_repo.update( + await users_repo.partial_update( id=self.id, silence_end=self.silence_end, ) @@ -555,7 +555,7 @@ async def unsilence(self, admin: Player, reason: str) -> None: """Unsilence `self`, and log to sql.""" self.silence_end = int(time.time()) - await users_repo.update( + await users_repo.partial_update( id=self.id, silence_end=self.silence_end, ) @@ -973,7 +973,7 @@ async def stats_from_sql_full(self) -> None: def update_latest_activity_soon(self) -> None: """Update the player's latest activity in the database.""" - task = users_repo.update( + task = users_repo.partial_update( id=self.id, latest_activity=int(time.time()), ) diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py index 75009bc1b..74d2ac295 100644 --- a/app/repositories/__init__.py +++ b/app/repositories/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import registry @@ -14,10 +13,3 @@ class Base(metaclass=DeclarativeMeta): metadata = mapper_registry.metadata __init__ = mapper_registry.constructor - - -class MySQLDialect(MySQLDialect_mysqldb): - default_paramstyle = "named" - - -DIALECT = MySQLDialect() diff --git a/app/repositories/achievements.py b/app/repositories/achievements.py index e2ee00ae5..7f131d2cf 100644 --- a/app/repositories/achievements.py +++ b/app/repositories/achievements.py @@ -2,14 +2,12 @@ from collections.abc import Callable from typing import TYPE_CHECKING -from typing import Any from typing import TypedDict from typing import cast import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel -from app.repositories import DIALECT from app.repositories import Base if TYPE_CHECKING: @@ -29,7 +27,7 @@ class AchievementsTable(Base): __tablename__ = "achievements" - id = Column("id", Integer, primary_key=True) + id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True) file = Column("file", String(128), nullable=False) name = Column("name", String(128, collation="utf8"), nullable=False) desc = Column("desc", String(256, collation="utf8"), nullable=False) @@ -72,18 +70,13 @@ async def create( desc=desc, cond=cond, ) - compiled = insert_stmt.compile(dialect=DIALECT) + rec_id = await app.state.services.database.execute(insert_stmt) - rec_id = await app.state.services.database.execute(str(compiled), compiled.params) + select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == rec_id) + achievement = await app.state.services.database.fetch_one(select_stmt) + assert achievement is not None - select_stmt = select(READ_PARAMS).where(AchievementsTable.id == rec_id) - compiled = select_stmt.compile(dialect=DIALECT) - - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) - assert rec is not None - - achievement = dict(rec._mapping) - achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') + achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}') return cast(Achievement, achievement) @@ -95,32 +88,28 @@ async def fetch_one( if id is None and name is None: raise ValueError("Must provide at least one parameter.") - select_stmt = select(READ_PARAMS) + 
select_stmt = select(*READ_PARAMS) if id is not None: select_stmt = select_stmt.where(AchievementsTable.id == id) if name is not None: select_stmt = select_stmt.where(AchievementsTable.name == name) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) - - if rec is None: + achievement = await app.state.services.database.fetch_one(select_stmt) + if achievement is None: return None - achievement = dict(rec._mapping) - achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') + achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}') return cast(Achievement, achievement) async def fetch_count() -> int: """Fetch the number of achievements.""" select_stmt = select(func.count().label("count")).select_from(AchievementsTable) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -128,23 +117,13 @@ async def fetch_many( page_size: int | None = None, ) -> list[Achievement]: """Fetch a list of achievements.""" - select_stmt = select(READ_PARAMS) + select_stmt = select(*READ_PARAMS) if page is not None and page_size is not None: select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - compiled = select_stmt.compile(dialect=DIALECT) - - records = await app.state.services.database.fetch_all( - str(compiled), - compiled.params, - ) - - achievements: list[dict[str, Any]] = [] - - for rec in records: - achievement = dict(rec._mapping) - achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') - achievements.append(achievement) + achievements = await app.state.services.database.fetch_all(select_stmt) + for achievement in achievements: + achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}') return cast(list[Achievement], achievements) @@ -167,16 +146,14 @@ async def partial_update( if not isinstance(cond, _UnsetSentinel): update_stmt = update_stmt.values(cond=cond) - compiled = update_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute(str(compiled), compiled.params) + await app.state.services.database.execute(update_stmt) - select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) - assert rec is not None + select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == id) + achievement = await app.state.services.database.fetch_one(select_stmt) + if achievement is None: + return None - achievement = dict(rec._mapping) - achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') + achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}') return cast(Achievement, achievement) @@ -184,16 +161,13 @@ async def delete_one( id: int, ) -> Achievement | None: """Delete an existing achievement.""" - select_stmt = select(READ_PARAMS).where(AchievementsTable.id == id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) - if rec is None: + select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == id) + achievement = await app.state.services.database.fetch_one(select_stmt) + if achievement is None: return None 
delete_stmt = delete(AchievementsTable).where(AchievementsTable.id == id) - compiled = delete_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute(str(compiled), compiled.params) + await app.state.services.database.execute(delete_stmt) - achievement = dict(rec._mapping) - achievement["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}') + achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}') return cast(Achievement, achievement) diff --git a/app/repositories/channels.py b/app/repositories/channels.py index 35f7418f4..d478a755b 100644 --- a/app/repositories/channels.py +++ b/app/repositories/channels.py @@ -1,7 +1,5 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast @@ -19,14 +17,13 @@ import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel -from app.repositories import DIALECT from app.repositories import Base class ChannelsTable(Base): __tablename__ = "channels" - id = Column("id", Integer, primary_key=True) + id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True) name = Column("name", String(32), nullable=False) topic = Column("topic", String(256), nullable=False) read_priv = Column("read_priv", Integer, nullable=False, server_default="1") @@ -73,18 +70,13 @@ async def create( write_priv=write_priv, auto_join=auto_join, ) - compiled = insert_stmt.compile(dialect=DIALECT) - rec_id = await app.state.services.database.execute(str(compiled), compiled.params) - - select_stmt = select(READ_PARAMS).where(ChannelsTable.id == rec_id) - compiled = select_stmt.compile(dialect=DIALECT) - channel = await app.state.services.database.fetch_one( - str(compiled), - compiled.params, - ) + rec_id = await app.state.services.database.execute(insert_stmt) + + select_stmt = select(*READ_PARAMS).where(ChannelsTable.id == rec_id) + channel = await app.state.services.database.fetch_one(select_stmt) assert channel is not None - return cast(Channel, dict(channel._mapping)) + return cast(Channel, channel) async def fetch_one( @@ -95,20 +87,15 @@ async def fetch_one( if id is None and name is None: raise ValueError("Must provide at least one parameter.") - select_stmt = select(READ_PARAMS) + select_stmt = select(*READ_PARAMS) if id is not None: select_stmt = select_stmt.where(ChannelsTable.id == id) if name is not None: select_stmt = select_stmt.where(ChannelsTable.name == name) - compiled = select_stmt.compile(dialect=DIALECT) - channel = await app.state.services.database.fetch_one( - str(compiled), - compiled.params, - ) - - return cast(Channel, dict(channel._mapping)) if channel is not None else None + channel = await app.state.services.database.fetch_one(select_stmt) + return cast(Channel | None, channel) async def fetch_count( @@ -128,10 +115,9 @@ async def fetch_count( if auto_join is not None: select_stmt = select_stmt.where(ChannelsTable.auto_join == auto_join) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -142,7 +128,7 @@ async def fetch_many( page_size: int | None = None, ) -> list[Channel]: """Fetch multiple channels from the database.""" - select_stmt = select(READ_PARAMS) + select_stmt = select(*READ_PARAMS) if read_priv is not None: select_stmt 
= select_stmt.where(ChannelsTable.read_priv == read_priv) @@ -154,12 +140,8 @@ async def fetch_many( if page is not None and page_size is not None: select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - compiled = select_stmt.compile(dialect=DIALECT) - channels = await app.state.services.database.fetch_all( - str(compiled), - compiled.params, - ) - return cast(list[Channel], [dict(c._mapping) for c in channels]) + channels = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Channel], channels) async def partial_update( @@ -181,29 +163,22 @@ async def partial_update( if not isinstance(auto_join, _UnsetSentinel): update_stmt = update_stmt.values(auto_join=auto_join) - compiled = update_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute(str(compiled), compiled.params) + await app.state.services.database.execute(update_stmt) - select_stmt = select(READ_PARAMS).where(ChannelsTable.name == name) - compiled = select_stmt.compile(dialect=DIALECT) - channel = await app.state.services.database.fetch_one( - str(compiled), - compiled.params, - ) - return cast(Channel, dict(channel._mapping)) if channel is not None else None + select_stmt = select(*READ_PARAMS).where(ChannelsTable.name == name) + channel = await app.state.services.database.fetch_one(select_stmt) + return cast(Channel | None, channel) async def delete_one( name: str, ) -> Channel | None: """Delete a channel from the database.""" - select_stmt = select(READ_PARAMS).where(ChannelsTable.name == name) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one(str(compiled), compiled.params) - if rec is None: + select_stmt = select(*READ_PARAMS).where(ChannelsTable.name == name) + channel = await app.state.services.database.fetch_one(select_stmt) + if channel is None: return None delete_stmt = delete(ChannelsTable).where(ChannelsTable.name == name) - compiled = delete_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute(str(compiled), compiled.params) - return cast(Channel, dict(rec._mapping)) if rec is not None else None + await app.state.services.database.execute(delete_stmt) + return cast(Channel | None, channel) diff --git a/app/repositories/clans.py b/app/repositories/clans.py index b894edf75..0fa8f15f3 100644 --- a/app/repositories/clans.py +++ b/app/repositories/clans.py @@ -1,29 +1,48 @@ from __future__ import annotations -import textwrap from datetime import datetime -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import Base + + +class ClansTable(Base): + __tablename__ = "clans" + + id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True) + name = Column("name", String(16, collation="utf8"), nullable=False) + tag = Column("tag", String(6, collation="utf8"), nullable=False) + owner = Column("owner", Integer, nullable=False) + created_at = Column("created_at", DateTime, nullable=False) + + __table_args__ = ( + Index("clans_name_uindex", name, unique=False), + Index("clans_owner_uindex", owner, unique=True), + 
Index("clans_tag_uindex", tag, unique=True), + ) -# +------------+-------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +------------+-------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | name | varchar(16) | NO | UNI | NULL | | -# | tag | varchar(6) | NO | UNI | NULL | | -# | owner | int | NO | UNI | NULL | | -# | created_at | datetime | NO | | NULL | | -# +------------+-------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, name, tag, owner, created_at - """, + +READ_PARAMS = ( + ClansTable.id, + ClansTable.name, + ClansTable.tag, + ClansTable.owner, + ClansTable.created_at, ) @@ -35,41 +54,25 @@ class Clan(TypedDict): created_at: datetime -class ClanUpdateFields(TypedDict, total=False): - name: str - tag: str - owner: int - - async def create( name: str, tag: str, owner: int, ) -> Clan: """Create a new clan in the database.""" - query = f"""\ - INSERT INTO clans (name, tag, owner, created_at) - VALUES (:name, :tag, :owner, NOW()) - """ - params: dict[str, Any] = { - "name": name, - "tag": tag, - "owner": owner, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM clans - WHERE id = :id - """ - params = { - "id": rec_id, - } - clan = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(ClansTable).values( + name=name, + tag=tag, + owner=owner, + created_at=func.now(), + ) + rec_id = await app.state.services.database.execute(insert_stmt) + + select_stmt = select(*READ_PARAMS).where(ClansTable.id == rec_id) + clan = await app.state.services.database.fetch_one(select_stmt) assert clan is not None - return cast(Clan, dict(clan._mapping)) + return cast(Clan, clan) async def fetch_one( @@ -82,34 +85,28 @@ async def fetch_one( if id is None and name is None and tag is None and owner is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM clans - WHERE id = COALESCE(:id, id) - AND name = COALESCE(:name, name) - AND tag = COALESCE(:tag, tag) - AND owner = COALESCE(:owner, owner) - """ - params: dict[str, Any] = { - "id": id, - "name": name, - "tag": tag, - "owner": owner, - } - clan = await app.state.services.database.fetch_one(query, params) - - return cast(Clan, dict(clan._mapping)) if clan is not None else None + select_stmt = select(*READ_PARAMS) + + if id is not None: + select_stmt = select_stmt.where(ClansTable.id == id) + if name is not None: + select_stmt = select_stmt.where(ClansTable.name == name) + if tag is not None: + select_stmt = select_stmt.where(ClansTable.tag == tag) + if owner is not None: + select_stmt = select_stmt.where(ClansTable.owner == owner) + + clan = await app.state.services.database.fetch_one(select_stmt) + return cast(Clan | None, clan) async def fetch_count() -> int: """Fetch the number of clans in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM clans - """ - rec = await app.state.services.database.fetch_one(query) + select_stmt = select(func.count().label("count")).select_from(ClansTable) + rec = await app.state.services.database.fetch_one(select_stmt) + assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -117,81 +114,43 @@ async def fetch_many( page_size: int | None = None, ) -> list[Clan]: """Fetch many clans from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - 
FROM clans - """ - params: dict[str, Any] = {} - + select_stmt = select(*READ_PARAMS) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - clans = await app.state.services.database.fetch_all(query, params) - return cast(list[Clan], [dict(c._mapping) for c in clans]) + clans = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Clan], clans) -async def update( +async def partial_update( id: int, name: str | _UnsetSentinel = UNSET, tag: str | _UnsetSentinel = UNSET, owner: int | _UnsetSentinel = UNSET, ) -> Clan | None: """Update a clan in the database.""" - update_fields: ClanUpdateFields = {} + update_stmt = update(ClansTable).where(ClansTable.id == id) if not isinstance(name, _UnsetSentinel): - update_fields["name"] = name + update_stmt = update_stmt.values(name=name) if not isinstance(tag, _UnsetSentinel): - update_fields["tag"] = tag + update_stmt = update_stmt.values(tag=tag) if not isinstance(owner, _UnsetSentinel): - update_fields["owner"] = owner - - query = f"""\ - UPDATE clans - SET {",".join(f"{k} = :{k}" for k in update_fields)} - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } | update_fields - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM clans - WHERE id = :id - """ - params = { - "id": id, - } - clan = await app.state.services.database.fetch_one(query, params) - return cast(Clan, dict(clan._mapping)) if clan is not None else None - - -async def delete(id: int) -> Clan | None: + update_stmt = update_stmt.values(owner=owner) + + await app.state.services.database.execute(update_stmt) + + select_stmt = select(*READ_PARAMS).where(ClansTable.id == id) + clan = await app.state.services.database.fetch_one(select_stmt) + return cast(Clan | None, clan) + + +async def delete_one(id: int) -> Clan | None: """Delete a clan from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM clans - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - clan = await app.state.services.database.fetch_one(query, params) + select_stmt = select(*READ_PARAMS).where(ClansTable.id == id) + clan = await app.state.services.database.fetch_one(select_stmt) if clan is None: return None - query = """\ - DELETE FROM clans - WHERE id = :id - """ - params = { - "id": id, - } - await app.state.services.database.execute(query, params) - return cast(Clan, dict(clan._mapping)) + delete_stmt = delete(ClansTable).where(ClansTable.id == id) + await app.state.services.database.execute(delete_stmt) + return cast(Clan, clan) diff --git a/app/repositories/client_hashes.py b/app/repositories/client_hashes.py index 268429552..67f3494bc 100644 --- a/app/repositories/client_hashes.py +++ b/app/repositories/client_hashes.py @@ -1,29 +1,44 @@ from __future__ import annotations -import textwrap from datetime import datetime -from typing import Any from typing import TypedDict from typing import cast -import app.state.services +from sqlalchemy import CHAR +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Integer +from sqlalchemy import func +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy.dialects.mysql import Insert as MysqlInsert +from sqlalchemy.dialects.mysql import insert as mysql_insert -# 
+--------------+------------------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +--------------+------------------------+------+-----+---------+----------------+ -# | userid | int | NO | PRI | NULL | | -# | osupath | char(32) | NO | PRI | NULL | | -# | adapters | char(32) | NO | PRI | NULL | | -# | uninstall_id | char(32) | NO | PRI | NULL | | -# | disk_serial | char(32) | NO | PRI | NULL | | -# | latest_time | datetime | NO | | NULL | | -# | occurrences | int | NO | | 0 | | -# +--------------+------------------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - userid, osupath, adapters, uninstall_id, disk_serial, latest_time, occurrences - """, +import app.state.services +from app.repositories import Base +from app.repositories.users import UsersTable + + +class ClientHashesTable(Base): + __tablename__ = "client_hashes" + + userid = Column("userid", Integer, nullable=False, primary_key=True) + osupath = Column("osupath", CHAR(32), nullable=False, primary_key=True) + adapters = Column("adapters", CHAR(32), nullable=False, primary_key=True) + uninstall_id = Column("uninstall_id", CHAR(32), nullable=False, primary_key=True) + disk_serial = Column("disk_serial", CHAR(32), nullable=False, primary_key=True) + latest_time = Column("latest_time", DateTime, nullable=False) + occurrences = Column("occurrences", Integer, nullable=False, server_default="0") + + +READ_PARAMS = ( + ClientHashesTable.userid, + ClientHashesTable.osupath, + ClientHashesTable.adapters, + ClientHashesTable.uninstall_id, + ClientHashesTable.disk_serial, + ClientHashesTable.latest_time, + ClientHashesTable.occurrences, ) @@ -50,44 +65,37 @@ async def create( disk_serial: str, ) -> ClientHash: """Create a new client hash entry in the database.""" - query = f"""\ - INSERT INTO client_hashes (userid, osupath, adapters, uninstall_id, - disk_serial, latest_time, occurrences) - VALUES (:userid, :osupath, :adapters, :uninstall_id, - :disk_serial, NOW(), 1) - ON DUPLICATE KEY UPDATE - latest_time = NOW(), - occurrences = occurrences + 1 - """ - params: dict[str, Any] = { - "userid": userid, - "osupath": osupath, - "adapters": adapters, - "uninstall_id": uninstall_id, - "disk_serial": disk_serial, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM client_hashes - WHERE userid = :userid - AND osupath = :osupath - AND adapters = :adapters - AND uninstall_id = :uninstall_id - AND disk_serial = :disk_serial - """ - params = { - "userid": userid, - "osupath": osupath, - "adapters": adapters, - "uninstall_id": uninstall_id, - "disk_serial": disk_serial, - } - client_hash = await app.state.services.database.fetch_one(query, params) + insert_stmt: MysqlInsert = ( + mysql_insert(ClientHashesTable) + .values( + userid=userid, + osupath=osupath, + adapters=adapters, + uninstall_id=uninstall_id, + disk_serial=disk_serial, + latest_time=func.now(), + occurrences=1, + ) + .on_duplicate_key_update( + latest_time=func.now(), + occurrences=ClientHashesTable.occurrences + 1, + ) + ) + + await app.state.services.database.execute(insert_stmt) + + select_stmt = ( + select(*READ_PARAMS) + .where(ClientHashesTable.userid == userid) + .where(ClientHashesTable.osupath == osupath) + .where(ClientHashesTable.adapters == adapters) + .where(ClientHashesTable.uninstall_id == uninstall_id) + .where(ClientHashesTable.disk_serial == disk_serial) + ) + client_hash = await 
app.state.services.database.fetch_one(select_stmt) assert client_hash is not None - return cast(ClientHash, dict(client_hash._mapping)) + return cast(ClientHash, client_hash) async def fetch_any_hardware_matches_for_user( @@ -102,42 +110,22 @@ async def fetch_any_hardware_matches_for_user( `adapters`, `uninstall_id` or `disk_serial` match other users from the database. """ - hw_check_params: dict[str, Any] + select_stmt = ( + select(*READ_PARAMS, UsersTable.name, UsersTable.priv) + .join(UsersTable, ClientHashesTable.userid == UsersTable.id) + .where(ClientHashesTable.userid != userid) + ) + if running_under_wine: - hw_check_subquery = """\ - h.uninstall_id = :uninstall_id - """ - hw_check_params = {"uninstall_id": uninstall_id} + select_stmt = select_stmt.where(ClientHashesTable.uninstall_id == uninstall_id) else: - assert ( - adapters is not None - and uninstall_id is not None - and disk_serial is not None + select_stmt = select_stmt.where( + or_( + ClientHashesTable.adapters == adapters, + ClientHashesTable.uninstall_id == uninstall_id, + ClientHashesTable.disk_serial == disk_serial, + ), ) - hw_check_subquery = """\ - h.adapters = :adapters - OR h.uninstall_id = :uninstall_id - OR h.disk_serial = :disk_serial - """ - hw_check_params = { - "adapters": adapters, - "uninstall_id": uninstall_id, - "disk_serial": disk_serial, - } - - query = f"""\ - SELECT {READ_PARAMS}, u.name, u.priv - FROM client_hashes h - INNER JOIN users u ON h.userid = u.id - WHERE h.userid != :userid AND ({hw_check_subquery}) - """ - params: dict[str, Any] = { - "userid": userid, - **hw_check_params, - } - - client_hashes = await app.state.services.database.fetch_all(query, params) - return cast( - list[ClientHashWithPlayer], - [dict(client_hash._mapping) for client_hash in client_hashes], - ) + + client_hashes = await app.state.services.database.fetch_all(select_stmt) + return cast(list[ClientHashWithPlayer], client_hashes) diff --git a/app/repositories/comments.py b/app/repositories/comments.py index 3746512d6..bfd96ae59 100644 --- a/app/repositories/comments.py +++ b/app/repositories/comments.py @@ -1,30 +1,23 @@ from __future__ import annotations -import textwrap from enum import StrEnum -from typing import Any from typing import TypedDict from typing import cast -import app.state.services +from sqlalchemy import CHAR +from sqlalchemy import Column +from sqlalchemy import Enum +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import and_ +from sqlalchemy import insert +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy.dialects.mysql import FLOAT -# +-----------------+-----------------------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +-----------------+-----------------------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | target_id | int | NO | | NULL | | -# | target_type | enum('replay','map','song') | NO | | NULL | | -# | userid | int | NO | | NULL | | -# | time | float(6,3) | NO | | NULL | | -# | comment | varchar(80) | NO | | NULL | | -# | colour | char(6) | YES | | NULL | | -# +-----------------+-----------------------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - c.id, c.target_id, c.target_type, c.userid, c.time, c.comment, c.colour - """, -) +import app.state.services +from app.repositories import Base +from app.repositories.users import UsersTable class TargetType(StrEnum): @@ -33,6 +26,29 @@ 
class TargetType(StrEnum): SONG = "song" +class CommentsTable(Base): + __tablename__ = "comments" + + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) + target_id = Column("target_id", nullable=False) + target_type = Column(Enum(TargetType, name="target_type"), nullable=False) + userid = Column("userid", Integer, nullable=False) + time = Column("time", FLOAT(precision=6, scale=3), nullable=False) + comment = Column("comment", String(80, collation="utf8"), nullable=False) + colour = Column("colour", CHAR(6), nullable=True) + + +READ_PARAMS = ( + CommentsTable.id, + CommentsTable.target_id, + CommentsTable.target_type, + CommentsTable.userid, + CommentsTable.time, + CommentsTable.comment, + CommentsTable.colour, +) + + class Comment(TypedDict): id: int target_id: int @@ -52,32 +68,21 @@ async def create( colour: str | None, ) -> Comment: """Create a new comment entry in the database.""" - query = f"""\ - INSERT INTO comments (target_id, target_type, userid, time, comment, colour) - VALUES (:target_id, :target_type, :userid, :time, :comment, :colour) - """ - params: dict[str, Any] = { - "target_id": target_id, - "target_type": target_type, - "userid": userid, - "time": time, - "comment": comment, - "colour": colour, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM comments c - WHERE id = :id - """ - params = { - "id": rec_id, - } - _comment = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(CommentsTable).values( + target_id=target_id, + target_type=target_type, + userid=userid, + time=time, + comment=comment, + colour=colour, + ) + rec_id = await app.state.services.database.execute(insert_stmt) + + select_stmt = select(*READ_PARAMS).where(CommentsTable.id == rec_id) + _comment = await app.state.services.database.fetch_one(select_stmt) assert _comment is not None - return cast(Comment, dict(_comment._mapping)) + return cast(Comment, _comment) class CommentWithUserPrivileges(Comment): @@ -95,19 +100,26 @@ async def fetch_all_relevant_to_replay( - `map_set_id` - `map_id` """ - query = f"""\ - SELECT {READ_PARAMS}, u.priv - FROM comments c - INNER JOIN users u ON u.id = c.userid - WHERE (c.target_type = 'replay' AND c.target_id = :score_id) - OR (c.target_type = 'song' AND c.target_id = :map_set_id) - OR (c.target_type = 'map' AND c.target_id = :map_id) - """ - params: dict[str, Any] = { - "score_id": score_id, - "map_set_id": map_set_id, - "map_id": map_id, - } - - comments = await app.state.services.database.fetch_all(query, params) - return cast(list[CommentWithUserPrivileges], [dict(c._mapping) for c in comments]) + select_stmt = ( + select(READ_PARAMS, UsersTable.priv) + .join(UsersTable, CommentsTable.userid == UsersTable.id) + .where( + or_( + and_( + CommentsTable.target_type == TargetType.REPLAY, + CommentsTable.target_id == score_id, + ), + and_( + CommentsTable.target_type == TargetType.SONG, + CommentsTable.target_id == map_set_id, + ), + and_( + CommentsTable.target_type == TargetType.BEATMAP, + CommentsTable.target_id == map_id, + ), + ), + ) + ) + + comments = await app.state.services.database.fetch_all(select_stmt) + return cast(list[CommentWithUserPrivileges], comments) diff --git a/app/repositories/favourites.py b/app/repositories/favourites.py index 7cdb52053..1c34547aa 100644 --- a/app/repositories/favourites.py +++ b/app/repositories/favourites.py @@ -1,24 +1,30 @@ from __future__ import annotations -import textwrap -from typing import Any 
from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select + import app.state.services +from app.repositories import Base + + +class FavouritesTable(Base): + __tablename__ = "favourites" -# +------------+------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +------------+------+------+-----+---------+-------+ -# | userid | int | NO | PRI | NULL | | -# | setid | int | NO | PRI | NULL | | -# | created_at | int | NO | | 0 | | -# +------------+------+------+-----+---------+-------+ - -READ_PARAMS = textwrap.dedent( - """\ - userid, setid, created_at - """, + userid = Column("userid", Integer, nullable=False, primary_key=True) + setid = Column("setid", Integer, nullable=False, primary_key=True) + created_at = Column("created_at", Integer, nullable=False, server_default="0") + + +READ_PARAMS = ( + FavouritesTable.userid, + FavouritesTable.setid, + FavouritesTable.created_at, ) @@ -33,59 +39,37 @@ async def create( setid: int, ) -> Favourite: """Create a new favourite mapset entry in the database.""" - query = f"""\ - INSERT INTO favourites (userid, setid, created_at) - VALUES (:userid, :setid, UNIX_TIMESTAMP()) - """ - params: dict[str, Any] = { - "userid": userid, - "setid": setid, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM favourites - WHERE userid = :userid - AND setid = :setid - """ - params = { - "userid": userid, - "setid": setid, - } - favourite = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(FavouritesTable).values( + userid=userid, + setid=setid, + created_at=func.unix_timestamp(), + ) + await app.state.services.database.execute(insert_stmt) + + select_stmt = ( + select(*READ_PARAMS) + .where(FavouritesTable.userid == userid) + .where(FavouritesTable.setid == setid) + ) + favourite = await app.state.services.database.fetch_one(select_stmt) assert favourite is not None - return cast(Favourite, dict(favourite._mapping)) + return cast(Favourite, favourite) async def fetch_all(userid: int) -> list[Favourite]: """Fetch all favourites from a player.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM favourites - WHERE userid = :userid - """ - params: dict[str, Any] = { - "userid": userid, - } - - favourites = await app.state.services.database.fetch_all(query, params) - return cast(list[Favourite], [dict(f._mapping) for f in favourites]) + select_stmt = select(*READ_PARAMS).where(FavouritesTable.userid == userid) + favourites = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Favourite], favourites) async def fetch_one(userid: int, setid: int) -> Favourite | None: """Check if a mapset is already a favourite.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM favourites - WHERE userid = :userid - AND setid = :setid - """ - params: dict[str, Any] = { - "userid": userid, - "setid": setid, - } - - favourite = await app.state.services.database.fetch_one(query, params) - return cast(Favourite, dict(favourite._mapping)) if favourite else None + select_stmt = ( + select(*READ_PARAMS) + .where(FavouritesTable.userid == userid) + .where(FavouritesTable.setid == setid) + ) + favourite = await app.state.services.database.fetch_one(select_stmt) + return cast(Favourite | None, favourite) diff --git a/app/repositories/ingame_logins.py b/app/repositories/ingame_logins.py index f639ece46..0a44983b9 
100644 --- a/app/repositories/ingame_logins.py +++ b/app/repositories/ingame_logins.py @@ -1,29 +1,41 @@ from __future__ import annotations -import textwrap from datetime import date from datetime import datetime -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Date +from sqlalchemy import DateTime +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select + import app.state.services +from app.repositories import Base + + +class IngameLoginsTable(Base): + __tablename__ = "ingame_logins" + + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) + userid = Column("userid", Integer, nullable=False) + ip = Column("ip", String(45), nullable=False) + osu_ver = Column("osu_ver", Date, nullable=False) + osu_stream = Column("osu_stream", String(11), nullable=False) + datetime = Column("datetime", DateTime, nullable=False) + -# +--------------+------------------------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +--------------+------------------------+------+-----+---------+-------+ -# | id | int | NO | PRI | NULL | | -# | userid | int | NO | | NULL | | -# | ip | varchar(45) | NO | | NULL | | -# | osu_ver | date | NO | | NULL | | -# | osu_stream | varchar(11) | NO | | NULL | | -# | datetime | datetime | NO | | NULL | | -# +--------------+------------------------+------+-----+---------+-------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, userid, ip, osu_ver, osu_stream, datetime - """, +READ_PARAMS = ( + IngameLoginsTable.id, + IngameLoginsTable.userid, + IngameLoginsTable.ip, + IngameLoginsTable.osu_ver, + IngameLoginsTable.osu_stream, + IngameLoginsTable.datetime, ) @@ -50,45 +62,27 @@ async def create( osu_stream: str, ) -> IngameLogin: """Create a new login entry in the database.""" - query = f"""\ - INSERT INTO ingame_logins (userid, ip, osu_ver, osu_stream, datetime) - VALUES (:userid, :ip, :osu_ver, :osu_stream, NOW()) - """ - params: dict[str, Any] = { - "userid": user_id, - "ip": ip, - "osu_ver": osu_ver, - "osu_stream": osu_stream, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM ingame_logins - WHERE id = :id - """ - params = { - "id": rec_id, - } - ingame_login = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(IngameLoginsTable).values( + userid=user_id, + ip=ip, + osu_ver=osu_ver, + osu_stream=osu_stream, + datetime=func.now(), + ) + rec_id = await app.state.services.database.execute(insert_stmt) + + select_stmt = select(*READ_PARAMS).where(IngameLoginsTable.id == rec_id) + ingame_login = await app.state.services.database.fetch_one(select_stmt) assert ingame_login is not None - return cast(IngameLogin, dict(ingame_login._mapping)) + return cast(IngameLogin, ingame_login) async def fetch_one(id: int) -> IngameLogin | None: """Fetch a login entry from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM ingame_logins - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - ingame_login = await app.state.services.database.fetch_one(query, params) - - return cast(IngameLogin, ingame_login) if ingame_login is not None else None + select_stmt = select(*READ_PARAMS).where(IngameLoginsTable.id == id) + ingame_login = await app.state.services.database.fetch_one(select_stmt) + return cast(IngameLogin | None, ingame_login) 
 async def fetch_count(
@@ -96,19 +90,15 @@ async def fetch_count(
     ip: str | None = None,
 ) -> int:
     """Fetch the number of logins in the database."""
-    query = """\
-        SELECT COUNT(*) AS count
-        FROM ingame_logins
-        WHERE userid = COALESCE(:userid, userid)
-        AND ip = COALESCE(:ip, ip)
-    """
-    params: dict[str, Any] = {
-        "userid": user_id,
-        "ip": ip,
-    }
-    rec = await app.state.services.database.fetch_one(query, params)
+    select_stmt = select(func.count().label("count")).select_from(IngameLoginsTable)
+    if user_id is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.userid == user_id)
+    if ip is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.ip == ip)
+
+    rec = await app.state.services.database.fetch_one(select_stmt)
     assert rec is not None
-    return cast(int, rec._mapping["count"])
+    return cast(int, rec["count"])
 
 
 async def fetch_many(
@@ -120,28 +110,19 @@ async def fetch_many(
     page_size: int | None = None,
 ) -> list[IngameLogin]:
     """Fetch a list of logins from the database."""
-    query = f"""\
-        SELECT {READ_PARAMS}
-        FROM ingame_logins
-        WHERE userid = COALESCE(:userid, userid)
-        AND ip = COALESCE(:ip, ip)
-        AND osu_ver = COALESCE(:osu_ver, osu_ver)
-        AND osu_stream = COALESCE(:osu_stream, osu_stream)
-    """
-    params: dict[str, Any] = {
-        "userid": user_id,
-        "ip": ip,
-        "osu_ver": osu_ver,
-        "osu_stream": osu_stream,
-    }
+    select_stmt = select(*READ_PARAMS)
+
+    if user_id is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.userid == user_id)
+    if ip is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.ip == ip)
+    if osu_ver is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.osu_ver == osu_ver)
+    if osu_stream is not None:
+        select_stmt = select_stmt.where(IngameLoginsTable.osu_stream == osu_stream)
 
     if page is not None and page_size is not None:
-        query += """\
-            LIMIT :limit
-            OFFSET :offset
-        """
-        params["limit"] = page_size
-        params["offset"] = (page - 1) * page_size
-
-    ingame_logins = await app.state.services.database.fetch_all(query, params)
+        select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
+
+    ingame_logins = await app.state.services.database.fetch_all(select_stmt)
     return cast(list[IngameLogin], ingame_logins)
diff --git a/app/repositories/logs.py b/app/repositories/logs.py
index 082acb24a..a533fcbe7 100644
--- a/app/repositories/logs.py
+++ b/app/repositories/logs.py
@@ -1,28 +1,39 @@
 from __future__ import annotations
 
-import textwrap
 from datetime import datetime
-from typing import Any
 from typing import TypedDict
 from typing import cast
 
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy import func
+from sqlalchemy import insert
+from sqlalchemy import select
+
 import app.state.services
+from app.repositories import Base
+
+
+class LogTable(Base):
+    __tablename__ = "logs"
+
+    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
+    _from = Column("from", Integer, nullable=False)
+    to = Column("to", Integer, nullable=False)
+    action = Column("action", String(32), nullable=False)
+    msg = Column("msg", String(2048, collation="utf8"), nullable=True)
+    time = Column("time", DateTime, nullable=False, onupdate=func.now())
 
-# +--------------+------------------------+------+-----+---------+-----------------------------+
-# | Field        | Type                   | Null | Key | Default | Extra                       |
-# +--------------+------------------------+------+-----+---------+-----------------------------+
-# | id           | int                    | NO   | PRI |
NULL | auto_increment | -# | from | int | NO | | NULL | | -# | to | int | NO | | NULL | | -# | action | varchar(32) | NO | | NULL | | -# | msg | varchar(2048) | YES | | NULL | | -# | time | datetime | NO | | NULL | on update current_timestamp | -# +--------------+------------------------+------+-----+---------+-----------------------------+ -READ_PARAMS = textwrap.dedent( - """\ - `id`, `from`, `to`, `action`, `msg`, `time` - """, +READ_PARAMS = ( + LogTable.id, + LogTable._from.label("from"), + LogTable.to, + LogTable.action, + LogTable.msg, + LogTable.time, ) @@ -42,26 +53,18 @@ async def create( msg: str, ) -> Log: """Create a new log entry in the database.""" - query = f"""\ - INSERT INTO logs (`from`, `to`, `action`, `msg`, `time`) - VALUES (:from, :to, :action, :msg, NOW()) - """ - params: dict[str, Any] = { - "from": _from, - "to": to, - "action": action, - "msg": msg, - } - rec_id = await app.state.services.database.execute(query, params) + insert_stmt = insert(LogTable).values( + { + "from": _from, + "to": to, + "action": action, + "msg": msg, + "time": func.now(), + }, + ) + rec_id = await app.state.services.database.execute(insert_stmt) - query = f"""\ - SELECT {READ_PARAMS} - FROM logs - WHERE id = :id - """ - params = { - "id": rec_id, - } - rec = await app.state.services.database.fetch_one(query, params) - assert rec is not None - return cast(Log, dict(rec._mapping)) + select_stmt = select(*READ_PARAMS).where(LogTable.id == rec_id) + log = await app.state.services.database.fetch_one(select_stmt) + assert log is not None + return cast(Log, log) diff --git a/app/repositories/mail.py b/app/repositories/mail.py index c04928f33..6ce18e6e7 100644 --- a/app/repositories/mail.py +++ b/app/repositories/mail.py @@ -1,21 +1,40 @@ from __future__ import annotations -import textwrap from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services +from app.repositories import Base + + +class MailTable(Base): + __tablename__ = "mail" + + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) + from_id = Column("from_id", Integer, nullable=False) + to_id = Column("to_id", Integer, nullable=False) + msg = Column("msg", String(2048, collation="utf8"), nullable=False) + time = Column("time", Integer, nullable=True) + read = Column("read", TINYINT(1), nullable=False, server_default="0") -# +--------------+------------------------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +--------------+------------------------+------+-----+---------+-------+ -# | id | int | NO | PRI | NULL | | -# | from_id | int | NO | | NULL | | -# | to_id | int | NO | | NULL | | -# | msg | varchar(2048) | NO | | NULL | | -# | time | int | YES | | NULL | | -# | read | tinyint(1) | NO | | NULL | | -# +--------------+------------------------+------+-----+---------+-------+ + +READ_PARAMS = ( + MailTable.id, + MailTable.from_id, + MailTable.to_id, + MailTable.msg, + MailTable.time, + MailTable.read, +) class Mail(TypedDict): @@ -32,34 +51,23 @@ class MailWithUsernames(Mail): to_name: str -READ_PARAMS = textwrap.dedent( - """\ - id, from_id, to_id, msg, time, `read` - """, -) - - async def create(from_id: int, to_id: int, msg: str) -> Mail: """Create a new mail entry in the database.""" 
- query = f"""\ - INSERT INTO mail (from_id, to_id, msg, time) - VALUES (:from_id, :to_id, :msg, UNIX_TIMESTAMP()) - """ - params = {"from_id": from_id, "to_id": to_id, "msg": msg} - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM mail - WHERE id = :id - """ - params = { - "id": rec_id, - } - mail = await app.state.services.database.fetch_one(query, params) - + insert_stmt = insert(MailTable).values( + from_id=from_id, + to_id=to_id, + msg=msg, + time=func.unix_timestamp(), + ) + rec_id = await app.state.services.database.execute(insert_stmt) + + select_stmt = select(*READ_PARAMS).where(MailTable.id == rec_id) + mail = await app.state.services.database.fetch_one(select_stmt) assert mail is not None - return cast(Mail, dict(mail._mapping)) + return cast(Mail, mail) + + +from app.repositories.users import UsersTable async def fetch_all_mail_to_user( @@ -67,51 +75,39 @@ async def fetch_all_mail_to_user( read: bool | None = None, ) -> list[MailWithUsernames]: """Fetch all of mail to a given target from the database.""" - query = f"""\ - SELECT {READ_PARAMS}, - (SELECT name FROM users WHERE id = m.`from_id`) AS `from_name`, - (SELECT name FROM users WHERE id = m.`to_id`) AS `to_name` - FROM `mail` m - WHERE m.`to_id` = :to_id - AND m.`read` = COALESCE(:read, `read`) - """ - params = { - "to_id": user_id, - "read": read, - } - - mail = await app.state.services.database.fetch_all(query, params) - return cast(list[MailWithUsernames], [dict(m._mapping) for m in mail]) + from_subquery = select(UsersTable.name).where(UsersTable.id == MailTable.from_id) + to_subquery = select(UsersTable.name).where(UsersTable.id == MailTable.to_id) + + select_stmt = select( + *READ_PARAMS, + from_subquery.label("from_name"), + to_subquery.label("to_name"), + ).where(MailTable.to_id == user_id) + + if read is not None: + select_stmt = select_stmt.where(MailTable.read == read) + + mail = await app.state.services.database.fetch_all(select_stmt) + return cast(list[MailWithUsernames], mail) async def mark_conversation_as_read(to_id: int, from_id: int) -> list[Mail]: """Mark any mail in a user's conversation with another user as read.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM mail - WHERE to_id = :to_id - AND from_id = :from_id - AND `read` = False - """ - params = { - "to_id": to_id, - "from_id": from_id, - } - all_mail = await app.state.services.database.fetch_all(query, params) - if not all_mail: + select_stmt = select(*READ_PARAMS).where( + MailTable.to_id == to_id, + MailTable.from_id == from_id, + MailTable.read == False, + ) + mail = await app.state.services.database.fetch_all(select_stmt) + if not mail: return [] - query = """\ - UPDATE mail - SET `read` = True - WHERE to_id = :to_id - AND from_id = :from_id - AND `read` = False - """ - params = { - "to_id": to_id, - "from_id": from_id, - } - await app.state.services.database.execute(query, params) - - return cast(list[Mail], [dict(mail._mapping) for mail in all_mail]) + update_stmt = ( + update(MailTable) + .where(MailTable.to_id == to_id) + .where(MailTable.from_id == from_id) + .where(MailTable.read == False) + .values(read=True) + ) + await app.state.services.database.execute(update_stmt) + return cast(list[Mail], mail) diff --git a/app/repositories/map_requests.py b/app/repositories/map_requests.py index 9e2c98af7..1b193f772 100644 --- a/app/repositories/map_requests.py +++ b/app/repositories/map_requests.py @@ -1,27 +1,38 @@ from __future__ import annotations -import textwrap from 
datetime import datetime
 from typing import Any
 from typing import TypedDict
 from typing import cast
 
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Integer
+from sqlalchemy import func
+from sqlalchemy import insert
+from sqlalchemy import select
+from sqlalchemy import update
+from sqlalchemy.dialects.mysql import TINYINT
+
 import app.state.services
+from app.repositories import Base
+
+
+class MapRequestsTable(Base):
+    __tablename__ = "map_requests"
+
+    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
+    map_id = Column("map_id", Integer, nullable=False)
+    player_id = Column("player_id", Integer, nullable=False)
+    datetime = Column("datetime", DateTime, nullable=False)
+    active = Column("active", TINYINT(1), nullable=False)
 
-# +--------------+------------------------+------+-----+---------+----------------+
-# | Field        | Type                   | Null | Key | Default | Extra          |
-# +--------------+------------------------+------+-----+---------+----------------+
-# | id           | int                    | NO   | PRI | NULL    | auto_increment |
-# | map_id       | int                    | NO   |     | osu!    |                |
-# | player_id    | int                    | NO   |     | NULL    |                |
-# | datetime     | datetime               | NO   |     | NULL    |                |
-# | active       | tinyint(1)             | NO   |     | NULL    |                |
-# +--------------+------------------------+------+-----+---------+----------------+
-
-READ_PARAMS = textwrap.dedent(
-    """\
-        id, map_id, player_id, datetime, active
-    """,
+
+READ_PARAMS = (
+    MapRequestsTable.id,
+    MapRequestsTable.map_id,
+    MapRequestsTable.player_id,
+    MapRequestsTable.datetime,
+    MapRequestsTable.active,
 )
 
 
@@ -39,29 +50,19 @@ async def create(
     active: bool,
 ) -> MapRequest:
     """Create a new map request entry in the database."""
-    query = f"""\
-        INSERT INTO map_requests (map_id, player_id, datetime, active)
-        VALUES (:map_id, :player_id, NOW(), :active)
-    """
-    params: dict[str, Any] = {
-        "map_id": map_id,
-        "player_id": player_id,
-        "active": active,
-    }
-    rec_id = await app.state.services.database.execute(query, params)
-
-    query = f"""\
-        SELECT {READ_PARAMS}
-        FROM map_requests
-        WHERE id = :id
-    """
-    params = {
-        "id": rec_id,
-    }
-    map_request = await app.state.services.database.fetch_one(query, params)
-
+    insert_stmt = insert(MapRequestsTable).values(
+        map_id=map_id,
+        player_id=player_id,
+        datetime=func.now(),
+        active=active,
+    )
+    rec_id = await app.state.services.database.execute(insert_stmt)
+
+    select_stmt = select(*READ_PARAMS).where(MapRequestsTable.id == rec_id)
+    map_request = await app.state.services.database.fetch_one(select_stmt)
     assert map_request is not None
-    return cast(MapRequest, dict(map_request._mapping))
+
+    return cast(MapRequest, map_request)
 
 
 async def fetch_all(
@@ -70,42 +71,27 @@ async def fetch_all(
     active: bool | None = None,
 ) -> list[MapRequest]:
     """Fetch a list of map requests from the database."""
-    query = f"""\
-        SELECT {READ_PARAMS}
-        FROM map_requests
-        WHERE map_id = COALESCE(:map_id, map_id)
-        AND player_id = COALESCE(:player_id, player_id)
-        AND active = COALESCE(:active, active)
-    """
-    params: dict[str, Any] = {
-        "map_id": map_id,
-        "player_id": player_id,
-        "active": active,
-    }
-
-    map_requests = await app.state.services.database.fetch_all(query, params)
-    return cast(list[MapRequest], [dict(m._mapping) for m in map_requests])
+    select_stmt = select(*READ_PARAMS)
+    if map_id is not None:
+        select_stmt = select_stmt.where(MapRequestsTable.map_id == map_id)
+    if player_id is not None:
+        select_stmt = select_stmt.where(MapRequestsTable.player_id == player_id)
+    if active is not None:
+        select_stmt =
select_stmt.where(MapRequestsTable.active == active) + + map_requests = await app.state.services.database.fetch_all(select_stmt) + return cast(list[MapRequest], map_requests) async def mark_batch_as_inactive(map_ids: list[Any]) -> list[MapRequest]: """Mark a map request as inactive.""" - query = f"""\ - UPDATE map_requests - SET active = False - WHERE map_id IN :map_ids - """ - params: dict[str, Any] = { - "map_ids": map_ids, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM map_requests - WHERE map_id IN :map_ids - """ - params = { - "map_ids": map_ids, - } - map_requests = await app.state.services.database.fetch_all(query, params) - return cast(list[MapRequest], [dict(m._mapping) for m in map_requests]) + update_stmt = ( + update(MapRequestsTable) + .where(MapRequestsTable.map_id.in_(map_ids)) + .values(active=False) + ) + await app.state.services.database.execute(update_stmt) + + select_stmt = select(*READ_PARAMS).where(MapRequestsTable.map_id.in_(map_ids)) + map_requests = await app.state.services.database.fetch_all(select_stmt) + return cast(list[MapRequest], map_requests) diff --git a/app/repositories/maps.py b/app/repositories/maps.py index 049e8318e..a870b7421 100644 --- a/app/repositories/maps.py +++ b/app/repositories/maps.py @@ -1,49 +1,103 @@ from __future__ import annotations -import textwrap from datetime import datetime -from typing import Any +from enum import StrEnum from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Enum +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.dialects.mysql import FLOAT +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import Base + + +class MapServer(StrEnum): + OSU = "osu!" + PRIVATE = "private" + + +class MapsTable(Base): + __tablename__ = "maps" -# +--------------+------------------------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +--------------+------------------------+------+-----+---------+-------+ -# | id | int | NO | PRI | NULL | | -# | server | enum('osu!','private') | NO | | osu! 
| | -# | set_id | int | NO | | NULL | | -# | status | int | NO | | NULL | | -# | md5 | char(32) | NO | UNI | NULL | | -# | artist | varchar(128) | NO | | NULL | | -# | title | varchar(128) | NO | | NULL | | -# | version | varchar(128) | NO | | NULL | | -# | creator | varchar(19) | NO | | NULL | | -# | filename | varchar(256) | NO | | NULL | | -# | last_update | datetime | NO | | NULL | | -# | total_length | int | NO | | NULL | | -# | max_combo | int | NO | | NULL | | -# | frozen | tinyint(1) | NO | | 0 | | -# | plays | int | NO | | 0 | | -# | passes | int | NO | | 0 | | -# | mode | tinyint(1) | NO | | 0 | | -# | bpm | float(12,2) | NO | | 0.00 | | -# | cs | float(4,2) | NO | | 0.00 | | -# | ar | float(4,2) | NO | | 0.00 | | -# | od | float(4,2) | NO | | 0.00 | | -# | hp | float(4,2) | NO | | 0.00 | | -# | diff | float(6,3) | NO | | 0.000 | | -# +--------------+------------------------+------+-----+---------+-------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, server, set_id, status, md5, artist, title, version, creator, filename, - last_update, total_length, max_combo, frozen, plays, passes, mode, bpm, cs, - ar, od, hp, diff - """, + server = Column( + Enum(MapServer, name="server"), + nullable=False, + server_default="osu!", + primary_key=True, + ) + id = Column(Integer, nullable=False, primary_key=True) + set_id = Column(Integer, nullable=False) + status = Column(Integer, nullable=False) + md5 = Column(String(32), nullable=False) + artist = Column(String(128, collation="utf8"), nullable=False) + title = Column(String(128, collation="utf8"), nullable=False) + version = Column(String(128, collation="utf8"), nullable=False) + creator = Column(String(19, collation="utf8"), nullable=False) + filename = Column(String(256, collation="utf8"), nullable=False) + last_update = Column(DateTime, nullable=False) + total_length = Column(Integer, nullable=False) + max_combo = Column(Integer, nullable=False) + frozen = Column(TINYINT(1), nullable=False, server_default="0") + plays = Column(Integer, nullable=False, server_default="0") + passes = Column(Integer, nullable=False, server_default="0") + mode = Column(TINYINT(1), nullable=False, server_default="0") + bpm = Column(FLOAT(12, 2), nullable=False, server_default="0.00") + cs = Column(FLOAT(4, 2), nullable=False, server_default="0.00") + ar = Column(FLOAT(4, 2), nullable=False, server_default="0.00") + od = Column(FLOAT(4, 2), nullable=False, server_default="0.00") + hp = Column(FLOAT(4, 2), nullable=False, server_default="0.00") + diff = Column(FLOAT(6, 3), nullable=False, server_default="0.000") + + __table_args__ = ( + Index("maps_set_id_index", "set_id"), + Index("maps_status_index", "status"), + Index("maps_filename_index", "filename"), + Index("maps_plays_index", "plays"), + Index("maps_mode_index", "mode"), + Index("maps_frozen_index", "frozen"), + Index("maps_md5_uindex", "md5", unique=True), + Index("maps_id_uindex", "id", unique=True), + ) + + +READ_PARAMS = ( + MapsTable.id, + MapsTable.server, + MapsTable.set_id, + MapsTable.status, + MapsTable.md5, + MapsTable.artist, + MapsTable.title, + MapsTable.version, + MapsTable.creator, + MapsTable.filename, + MapsTable.last_update, + MapsTable.total_length, + MapsTable.max_combo, + MapsTable.frozen, + MapsTable.plays, + MapsTable.passes, + MapsTable.mode, + MapsTable.bpm, + MapsTable.cs, + MapsTable.ar, + MapsTable.od, + MapsTable.hp, + MapsTable.diff, ) @@ -73,31 +127,6 @@ class Map(TypedDict): diff: float -class MapUpdateFields(TypedDict, total=False): - server: str - set_id: int - status: 
int - md5: str - artist: str - title: str - version: str - creator: str - filename: str - last_update: datetime - total_length: int - max_combo: int - frozen: bool - plays: int - passes: int - mode: int - bpm: float - cs: float - ar: float - od: float - hp: float - diff: float - - async def create( id: int, server: str, @@ -124,55 +153,37 @@ async def create( diff: float, ) -> Map: """Create a new beatmap entry in the database.""" - query = f"""\ - INSERT INTO maps (id, server, set_id, status, md5, artist, title, - version, creator, filename, last_update, - total_length, max_combo, frozen, plays, passes, - mode, bpm, cs, ar, od, hp, diff) - VALUES (:id, :server, :set_id, :status, :md5, :artist, :title, - :version, :creator, :filename, :last_update, :total_length, - :max_combo, :frozen, :plays, :passes, :mode, :bpm, :cs, :ar, - :od, :hp, :diff) - """ - params: dict[str, Any] = { - "id": id, - "server": server, - "set_id": set_id, - "status": status, - "md5": md5, - "artist": artist, - "title": title, - "version": version, - "creator": creator, - "filename": filename, - "last_update": last_update, - "total_length": total_length, - "max_combo": max_combo, - "frozen": frozen, - "plays": plays, - "passes": passes, - "mode": mode, - "bpm": bpm, - "cs": cs, - "ar": ar, - "od": od, - "hp": hp, - "diff": diff, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE id = :id - """ - params = { - "id": rec_id, - } - map = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(MapsTable).values( + id=id, + server=server, + set_id=set_id, + status=status, + md5=md5, + artist=artist, + title=title, + version=version, + creator=creator, + filename=filename, + last_update=last_update, + total_length=total_length, + max_combo=max_combo, + frozen=frozen, + plays=plays, + passes=passes, + mode=mode, + bpm=bpm, + cs=cs, + ar=ar, + od=od, + hp=hp, + diff=diff, + ) + rec_id = await app.state.services.database.execute(insert_stmt) + select_stmt = select(*READ_PARAMS).where(MapsTable.id == rec_id) + map = await app.state.services.database.fetch_one(select_stmt) assert map is not None - return cast(Map, dict(map._mapping)) + return cast(Map, map) async def fetch_one( @@ -184,21 +195,16 @@ async def fetch_one( if id is None and md5 is None and filename is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE id = COALESCE(:id, id) - AND md5 = COALESCE(:md5, md5) - AND filename = COALESCE(:filename, filename) - """ - params: dict[str, Any] = { - "id": id, - "md5": md5, - "filename": filename, - } - map = await app.state.services.database.fetch_one(query, params) + select_stmt = select(*READ_PARAMS) + if id is not None: + select_stmt = select_stmt.where(MapsTable.id == id) + if md5 is not None: + select_stmt = select_stmt.where(MapsTable.md5 == md5) + if filename is not None: + select_stmt = select_stmt.where(MapsTable.filename == filename) - return cast(Map, dict(map._mapping)) if map is not None else None + map = await app.state.services.database.fetch_one(select_stmt) + return cast(Map | None, map) async def fetch_count( @@ -212,32 +218,27 @@ async def fetch_count( frozen: bool | None = None, ) -> int: """Fetch the number of maps in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM maps - WHERE server = COALESCE(:server, server) - AND set_id = COALESCE(:set_id, set_id) - AND status = COALESCE(:status, status) - AND artist = 
COALESCE(:artist, artist) - AND creator = COALESCE(:creator, creator) - AND filename = COALESCE(:filename, filename) - AND mode = COALESCE(:mode, mode) - AND frozen = COALESCE(:frozen, frozen) - - """ - params: dict[str, Any] = { - "server": server, - "set_id": set_id, - "status": status, - "artist": artist, - "creator": creator, - "filename": filename, - "mode": mode, - "frozen": frozen, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(func.count().label("count")).select_from(MapsTable) + if server is not None: + select_stmt = select_stmt.where(MapsTable.server == server) + if set_id is not None: + select_stmt = select_stmt.where(MapsTable.set_id == set_id) + if status is not None: + select_stmt = select_stmt.where(MapsTable.status == status) + if artist is not None: + select_stmt = select_stmt.where(MapsTable.artist == artist) + if creator is not None: + select_stmt = select_stmt.where(MapsTable.creator == creator) + if filename is not None: + select_stmt = select_stmt.where(MapsTable.filename == filename) + if mode is not None: + select_stmt = select_stmt.where(MapsTable.mode == mode) + if frozen is not None: + select_stmt = select_stmt.where(MapsTable.frozen == frozen) + + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -253,42 +254,32 @@ async def fetch_many( page_size: int | None = None, ) -> list[Map]: """Fetch a list of maps from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE server = COALESCE(:server, server) - AND set_id = COALESCE(:set_id, set_id) - AND status = COALESCE(:status, status) - AND artist = COALESCE(:artist, artist) - AND creator = COALESCE(:creator, creator) - AND filename = COALESCE(:filename, filename) - AND mode = COALESCE(:mode, mode) - AND frozen = COALESCE(:frozen, frozen) - """ - params: dict[str, Any] = { - "server": server, - "set_id": set_id, - "status": status, - "artist": artist, - "creator": creator, - "filename": filename, - "mode": mode, - "frozen": frozen, - } + select_stmt = select(*READ_PARAMS) + if server is not None: + select_stmt = select_stmt.where(MapsTable.server == server) + if set_id is not None: + select_stmt = select_stmt.where(MapsTable.set_id == set_id) + if status is not None: + select_stmt = select_stmt.where(MapsTable.status == status) + if artist is not None: + select_stmt = select_stmt.where(MapsTable.artist == artist) + if creator is not None: + select_stmt = select_stmt.where(MapsTable.creator == creator) + if filename is not None: + select_stmt = select_stmt.where(MapsTable.filename == filename) + if mode is not None: + select_stmt = select_stmt.where(MapsTable.mode == mode) + if frozen is not None: + select_stmt = select_stmt.where(MapsTable.frozen == frozen) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - maps = await app.state.services.database.fetch_all(query, params) - return cast(list[Map], [dict(m._mapping) for m in maps]) + maps = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Map], maps) -async def update( +async def partial_update( id: int, server: str | _UnsetSentinel = UNSET, set_id: int | _UnsetSentinel = UNSET, @@ -314,94 +305,66 @@ async def update( diff: float | 
_UnsetSentinel = UNSET, ) -> Map | None: """Update a beatmap entry in the database.""" - update_fields: MapUpdateFields = {} + update_stmt = update(MapsTable).where(MapsTable.id == id) if not isinstance(server, _UnsetSentinel): - update_fields["server"] = server + update_stmt = update_stmt.values(server=server) if not isinstance(set_id, _UnsetSentinel): - update_fields["set_id"] = set_id + update_stmt = update_stmt.values(set_id=set_id) if not isinstance(status, _UnsetSentinel): - update_fields["status"] = status + update_stmt = update_stmt.values(status=status) if not isinstance(md5, _UnsetSentinel): - update_fields["md5"] = md5 + update_stmt = update_stmt.values(md5=md5) if not isinstance(artist, _UnsetSentinel): - update_fields["artist"] = artist + update_stmt = update_stmt.values(artist=artist) if not isinstance(title, _UnsetSentinel): - update_fields["title"] = title + update_stmt = update_stmt.values(title=title) if not isinstance(version, _UnsetSentinel): - update_fields["version"] = version + update_stmt = update_stmt.values(version=version) if not isinstance(creator, _UnsetSentinel): - update_fields["creator"] = creator + update_stmt = update_stmt.values(creator=creator) if not isinstance(filename, _UnsetSentinel): - update_fields["filename"] = filename + update_stmt = update_stmt.values(filename=filename) if not isinstance(last_update, _UnsetSentinel): - update_fields["last_update"] = last_update + update_stmt = update_stmt.values(last_update=last_update) if not isinstance(total_length, _UnsetSentinel): - update_fields["total_length"] = total_length + update_stmt = update_stmt.values(total_length=total_length) if not isinstance(max_combo, _UnsetSentinel): - update_fields["max_combo"] = max_combo + update_stmt = update_stmt.values(max_combo=max_combo) if not isinstance(frozen, _UnsetSentinel): - update_fields["frozen"] = frozen + update_stmt = update_stmt.values(frozen=frozen) if not isinstance(plays, _UnsetSentinel): - update_fields["plays"] = plays + update_stmt = update_stmt.values(plays=plays) if not isinstance(passes, _UnsetSentinel): - update_fields["passes"] = passes + update_stmt = update_stmt.values(passes=passes) if not isinstance(mode, _UnsetSentinel): - update_fields["mode"] = mode + update_stmt = update_stmt.values(mode=mode) if not isinstance(bpm, _UnsetSentinel): - update_fields["bpm"] = bpm + update_stmt = update_stmt.values(bpm=bpm) if not isinstance(cs, _UnsetSentinel): - update_fields["cs"] = cs + update_stmt = update_stmt.values(cs=cs) if not isinstance(ar, _UnsetSentinel): - update_fields["ar"] = ar + update_stmt = update_stmt.values(ar=ar) if not isinstance(od, _UnsetSentinel): - update_fields["od"] = od + update_stmt = update_stmt.values(od=od) if not isinstance(hp, _UnsetSentinel): - update_fields["hp"] = hp + update_stmt = update_stmt.values(hp=hp) if not isinstance(diff, _UnsetSentinel): - update_fields["diff"] = diff - - query = f"""\ - UPDATE maps - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } | update_fields - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE id = :id - """ - params = { - "id": id, - } - map = await app.state.services.database.fetch_one(query, params) - return cast(Map, dict(map._mapping)) if map is not None else None - - -async def delete(id: int) -> Map | None: + update_stmt = update_stmt.values(diff=diff) + + await app.state.services.database.execute(update_stmt) + + 
select_stmt = select(*READ_PARAMS).where(MapsTable.id == id) + map = await app.state.services.database.fetch_one(select_stmt) + return cast(Map | None, map) + + +async def delete_one(id: int) -> Map | None: """Delete a beatmap entry from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - map = await app.state.services.database.fetch_one(query, params) + select_stmt = select(*READ_PARAMS).where(MapsTable.id == id) + map = await app.state.services.database.fetch_one(select_stmt) if map is None: return None - query = """\ - DELETE FROM maps - WHERE id = :id - """ - params = { - "id": id, - } - await app.state.services.database.execute(query, params) - return cast(Map, dict(map._mapping)) if map is not None else None + delete_stmt = delete(MapsTable).where(MapsTable.id == id) + await app.state.services.database.execute(delete_stmt) + return cast(Map, map) diff --git a/app/repositories/ratings.py b/app/repositories/ratings.py index c60fa0f20..a603555e6 100644 --- a/app/repositories/ratings.py +++ b/app/repositories/ratings.py @@ -1,24 +1,31 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services +from app.repositories import Base + + +class RatingsTable(Base): + __tablename__ = "ratings" + + userid = Column("userid", Integer, nullable=False, primary_key=True) + map_md5 = Column("map_md5", String(32), nullable=False, primary_key=True) + rating = Column("rating", TINYINT(2), nullable=False) -# +---------+----------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +---------+----------+------+-----+---------+-------+ -# | userid | int | NO | PRI | NULL | | -# | map_md5 | char(32) | NO | PRI | NULL | | -# | rating | tinyint | NO | | NULL | | -# +---------+----------+------+-----+---------+-------+ - -READ_PARAMS = textwrap.dedent( - """\ - userid, map_md5, rating - """, + +READ_PARAMS = ( + RatingsTable.userid, + RatingsTable.map_md5, + RatingsTable.rating, ) @@ -30,31 +37,21 @@ class Rating(TypedDict): async def create(userid: int, map_md5: str, rating: int) -> Rating: """Create a new rating.""" - query = """\ - INSERT INTO ratings (userid, map_md5, rating) - VALUES (:userid, :map_md5, :rating) - """ - params: dict[str, Any] = { - "userid": userid, - "map_md5": map_md5, - "rating": rating, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM ratings - WHERE userid = :userid - AND map_md5 = :map_md5 - """ - params = { - "userid": userid, - "map_md5": map_md5, - } - _rating = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(RatingsTable).values( + userid=userid, + map_md5=map_md5, + rating=rating, + ) + await app.state.services.database.execute(insert_stmt) + select_stmt = ( + select(*READ_PARAMS) + .where(RatingsTable.userid == userid) + .where(RatingsTable.map_md5 == map_md5) + ) + _rating = await app.state.services.database.fetch_one(select_stmt) assert _rating is not None - return cast(Rating, dict(_rating._mapping)) + return cast(Rating, _rating) async def fetch_many( @@ -64,39 +61,25 @@ async def fetch_many( page_size: int | None = 50, ) -> list[Rating]: """Fetch multiple ratings, 
optionally with filter params and pagination.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM ratings - WHERE userid = COALESCE(:userid, userid) - AND map_md5 = COALESCE(:map_md5, map_md5) - """ - params: dict[str, Any] = { - "userid": userid, - "map_md5": map_md5, - } + select_stmt = select(*READ_PARAMS) + if userid is not None: + select_stmt = select_stmt.where(RatingsTable.userid == userid) + if map_md5 is not None: + select_stmt = select_stmt.where(RatingsTable.map_md5 == map_md5) + if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size - ratings = await app.state.services.database.fetch_all(query, params) - return cast(list[Rating], [dict(r._mapping) for r in ratings]) + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + + ratings = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Rating], ratings) async def fetch_one(userid: int, map_md5: str) -> Rating | None: """Fetch a single rating for a given user and map.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM ratings - WHERE userid = :userid - AND map_md5 = :map_md5 - """ - params: dict[str, Any] = { - "userid": userid, - "map_md5": map_md5, - } - - rating = await app.state.services.database.fetch_one(query, params) - return cast(Rating, dict(rating._mapping)) if rating else None + select_stmt = ( + select(*READ_PARAMS) + .where(RatingsTable.userid == userid) + .where(RatingsTable.map_md5 == map_md5) + ) + rating = await app.state.services.database.fetch_one(select_stmt) + return cast(Rating | None, rating) diff --git a/app/repositories/scores.py b/app/repositories/scores.py index d1b9dc4fc..94868e8bf 100644 --- a/app/repositories/scores.py +++ b/app/repositories/scores.py @@ -19,14 +19,13 @@ import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel -from app.repositories import DIALECT from app.repositories import Base class ScoresTable(Base): __tablename__ = "scores" - id = Column("id", Integer, primary_key=True) + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) map_md5 = Column("map_md5", String(32), nullable=False) score = Column("score", Integer, nullable=False) pp = Column("pp", FLOAT(precision=6, scale=3), nullable=False) @@ -159,31 +158,18 @@ async def create( perfect=perfect, online_checksum=online_checksum, ) - compiled = insert_stmt.compile(dialect=DIALECT) - rec_id = await app.state.services.database.execute( - query=str(compiled), - values=compiled.params, - ) + rec_id = await app.state.services.database.execute(insert_stmt) - select_stmt = select(READ_PARAMS).where(ScoresTable.id == rec_id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) - assert rec is not None - return cast(Score, dict(rec._mapping)) + select_stmt = select(*READ_PARAMS).where(ScoresTable.id == rec_id) + _score = await app.state.services.database.fetch_one(select_stmt) + assert _score is not None + return cast(Score, _score) async def fetch_one(id: int) -> Score | None: - select_stmt = select(READ_PARAMS).where(ScoresTable.id == id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) - - return cast(Score, dict(rec._mapping)) if rec is not None else None + select_stmt = 
select(*READ_PARAMS).where(ScoresTable.id == id) + _score = await app.state.services.database.fetch_one(select_stmt) + return cast(Score | None, _score) async def fetch_count( @@ -205,13 +191,9 @@ async def fetch_count( if user_id is not None: select_stmt = select_stmt.where(ScoresTable.userid == user_id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -223,7 +205,7 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Score]: - select_stmt = select(READ_PARAMS) + select_stmt = select(*READ_PARAMS) if map_md5 is not None: select_stmt = select_stmt.where(ScoresTable.map_md5 == map_md5) if mods is not None: @@ -238,12 +220,8 @@ async def fetch_many( if page is not None and page_size is not None: select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - compiled = select_stmt.compile(dialect=DIALECT) - recs = await app.state.services.database.fetch_all( - query=str(compiled), - values=compiled.params, - ) - return cast(list[Score], [dict(r._mapping) for r in recs]) + scores = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Score], scores) async def partial_update( @@ -257,19 +235,12 @@ async def partial_update( update_stmt = update_stmt.values(pp=pp) if not isinstance(status, _UnsetSentinel): update_stmt = update_stmt.values(status=status) - compiled = update_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute( - query=str(compiled), - values=compiled.params, - ) - select_stmt = select(READ_PARAMS).where(ScoresTable.id == id) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) - return cast(Score, dict(rec._mapping)) if rec is not None else None + await app.state.services.database.execute(update_stmt) + + select_stmt = select(*READ_PARAMS).where(ScoresTable.id == id) + _score = await app.state.services.database.fetch_one(select_stmt) + return cast(Score | None, _score) # TODO: delete diff --git a/app/repositories/stats.py b/app/repositories/stats.py index 12177f875..ee8a9f7f4 100644 --- a/app/repositories/stats.py +++ b/app/repositories/stats.py @@ -10,22 +10,19 @@ from sqlalchemy import insert from sqlalchemy import select from sqlalchemy import update - -# from sqlalchemy import update from sqlalchemy.dialects.mysql import FLOAT from sqlalchemy.dialects.mysql import TINYINT import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel -from app.repositories import DIALECT from app.repositories import Base class StatsTable(Base): __tablename__ = "stats" - id = Column("id", Integer, primary_key=True) + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) mode = Column("mode", TINYINT(1), primary_key=True) tscore = Column("tscore", Integer, nullable=False, server_default="0") rscore = Column("rscore", Integer, nullable=False, server_default="0") @@ -97,20 +94,12 @@ class Stat(TypedDict): async def create(player_id: int, mode: int) -> Stat: """Create a new player stats entry in the database.""" insert_stmt = insert(StatsTable).values(id=player_id, mode=mode) - compiled = insert_stmt.compile(dialect=DIALECT) - rec_id = await 
app.state.services.database.execute( - query=str(compiled), - values=compiled.params, - ) + rec_id = await app.state.services.database.execute(insert_stmt) - select_stmt = select(READ_PARAMS).where(StatsTable.id == rec_id) - compiled = select_stmt.compile(dialect=DIALECT) - stat = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) + select_stmt = select(*READ_PARAMS).where(StatsTable.id == rec_id) + stat = await app.state.services.database.fetch_one(select_stmt) assert stat is not None - return cast(Stat, dict(stat._mapping)) + return cast(Stat, stat) async def create_all_modes(player_id: int) -> list[Stat]: @@ -130,31 +119,22 @@ async def create_all_modes(player_id: int) -> list[Stat]: ) ], ) - compiled = insert_stmt.compile(dialect=DIALECT) - await app.state.services.database.execute(str(compiled), compiled.params) + await app.state.services.database.execute(insert_stmt) - select_stmt = select(READ_PARAMS).where(StatsTable.id == player_id) - compiled = select_stmt.compile(dialect=DIALECT) - stats = await app.state.services.database.fetch_all( - query=str(compiled), - values=compiled.params, - ) - return cast(list[Stat], [dict(s._mapping) for s in stats]) + select_stmt = select(*READ_PARAMS).where(StatsTable.id == player_id) + stats = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Stat], stats) async def fetch_one(player_id: int, mode: int) -> Stat | None: """Fetch a player stats entry from the database.""" select_stmt = ( - select(READ_PARAMS) + select(*READ_PARAMS) .where(StatsTable.id == player_id) .where(StatsTable.mode == mode) ) - compiled = select_stmt.compile(dialect=DIALECT) - stat = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) - return cast(Stat, dict(stat._mapping)) if stat is not None else None + stat = await app.state.services.database.fetch_one(select_stmt) + return cast(Stat | None, stat) async def fetch_count( @@ -166,13 +146,10 @@ async def fetch_count( select_stmt = select_stmt.where(StatsTable.id == player_id) if mode is not None: select_stmt = select_stmt.where(StatsTable.mode == mode) - compiled = select_stmt.compile(dialect=DIALECT) - rec = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) + + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -181,19 +158,16 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Stat]: - select_stmt = select(READ_PARAMS) + select_stmt = select(*READ_PARAMS) if player_id is not None: select_stmt = select_stmt.where(StatsTable.id == player_id) if mode is not None: select_stmt = select_stmt.where(StatsTable.mode == mode) if page is not None and page_size is not None: select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - compiled = select_stmt.compile(dialect=DIALECT) - stats = await app.state.services.database.fetch_all( - query=str(compiled), - values=compiled.params, - ) - return cast(list[Stat], [dict(s._mapping) for s in stats]) + + stats = await app.state.services.database.fetch_all(select_stmt) + return cast(list[Stat], stats) async def partial_update( @@ -249,20 +223,15 @@ async def partial_update( if not isinstance(a_count, _UnsetSentinel): update_stmt = update_stmt.values(a_count=a_count) - compiled = update_stmt.compile(dialect=DIALECT) - await 
app.state.services.database.execute(str(compiled), compiled.params) + await app.state.services.database.execute(update_stmt) select_stmt = ( - select(READ_PARAMS) + select(*READ_PARAMS) .where(StatsTable.id == player_id) .where(StatsTable.mode == mode) ) - compiled = select_stmt.compile(dialect=DIALECT) - stat = await app.state.services.database.fetch_one( - query=str(compiled), - values=compiled.params, - ) - return cast(Stat, dict(stat._mapping)) if stat is not None else None + stat = await app.state.services.database.fetch_one(select_stmt) + return cast(Stat | None, stat) # TODO: delete? diff --git a/app/repositories/tourney_pool_maps.py b/app/repositories/tourney_pool_maps.py index 0e473410b..476da1d05 100644 --- a/app/repositories/tourney_pool_maps.py +++ b/app/repositories/tourney_pool_maps.py @@ -1,20 +1,39 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import delete +from sqlalchemy import insert +from sqlalchemy import select + import app.state.services +from app.repositories import Base + + +class TourneyPoolMapsTable(Base): + __tablename__ = "tourney_pool_maps" -# +---------+---------+------+-----+---------+-------+ -# | Field | Type | Null | Key | Default | Extra | -# +---------+---------+------+-----+---------+-------+ -# | map_id | int | NO | PRI | NULL | | -# | pool_id | int | NO | PRI | NULL | | -# | mods | int | NO | | NULL | | -# | slot | tinyint | NO | | NULL | | -# +---------+---------+------+-----+---------+-------+ + map_id = Column("map_id", Integer, nullable=False, primary_key=True) + pool_id = Column("pool_id", Integer, nullable=False, primary_key=True) + mods = Column("mods", Integer, nullable=False) + slot = Column("slot", Integer, nullable=False) + + __table_args__ = ( + Index("tourney_pool_maps_mods_slot_index", mods, slot), + Index("tourney_pool_maps_tourney_pools_id_fk", pool_id), + ) + + +READ_PARAMS = ( + TourneyPoolMapsTable.map_id, + TourneyPoolMapsTable.pool_id, + TourneyPoolMapsTable.mods, + TourneyPoolMapsTable.slot, +) class TourneyPoolMap(TypedDict): @@ -24,45 +43,24 @@ class TourneyPoolMap(TypedDict): slot: int -READ_PARAMS = textwrap.dedent( - """\ - map_id, pool_id, mods, slot - """, -) - - async def create(map_id: int, pool_id: int, mods: int, slot: int) -> TourneyPoolMap: """Create a new map pool entry in the database.""" - query = f"""\ - INSERT INTO tourney_pool_maps (map_id, pool_id, mods, slot) - VALUES (:map_id, :pool_id, :mods, :slot) - """ - params: dict[str, Any] = { - "map_id": map_id, - "pool_id": pool_id, - "mods": mods, - "slot": slot, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pool_maps - WHERE map_id = :map_id - AND pool_id = :pool_id - AND mods = :mods - AND slot = :slot - """ - params = { - "map_id": map_id, - "pool_id": pool_id, - "mods": mods, - "slot": slot, - } - tourney_pool_map = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(TourneyPoolMapsTable).values( + map_id=map_id, + pool_id=pool_id, + mods=mods, + slot=slot, + ) + await app.state.services.database.execute(insert_stmt) + select_stmt = ( + select(*READ_PARAMS) + .where(TourneyPoolMapsTable.map_id == map_id) + .where(TourneyPoolMapsTable.pool_id == pool_id) + ) + tourney_pool_map = await app.state.services.database.fetch_one(select_stmt) assert tourney_pool_map is not 
None - return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + return cast(TourneyPoolMap, tourney_pool_map) async def fetch_many( @@ -73,30 +71,18 @@ async def fetch_many( page_size: int | None = 50, ) -> list[TourneyPoolMap]: """Fetch a list of map pool entries from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pool_maps - WHERE pool_id = COALESCE(:pool_id, pool_id) - AND mods = COALESCE(:mods, mods) - AND slot = COALESCE(:slot, slot) - """ - params: dict[str, Any] = { - "pool_id": pool_id, - "mods": mods, - "slot": slot, - } + select_stmt = select(*READ_PARAMS) + if pool_id is not None: + select_stmt = select_stmt.where(TourneyPoolMapsTable.pool_id == pool_id) + if mods is not None: + select_stmt = select_stmt.where(TourneyPoolMapsTable.mods == mods) + if slot is not None: + select_stmt = select_stmt.where(TourneyPoolMapsTable.slot == slot) if page and page_size: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size - tourney_pool_maps = await app.state.services.database.fetch_all(query, params) - return cast( - list[TourneyPoolMap], - [dict(tourney_pool_map._mapping) for tourney_pool_map in tourney_pool_maps], - ) + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + + tourney_pool_maps = await app.state.services.database.fetch_all(select_stmt) + return cast(list[TourneyPoolMap], tourney_pool_maps) async def fetch_by_pool_and_pick( @@ -105,76 +91,47 @@ async def fetch_by_pool_and_pick( slot: int, ) -> TourneyPoolMap | None: """Fetch a map pool entry by pool and pick from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pool_maps - WHERE pool_id = :pool_id - AND mods = :mods - AND slot = :slot - """ - params: dict[str, Any] = { - "pool_id": pool_id, - "mods": mods, - "slot": slot, - } - tourney_pool_map = await app.state.services.database.fetch_one(query, params) - if tourney_pool_map is None: - return None - return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + select_stmt = ( + select(*READ_PARAMS) + .where(TourneyPoolMapsTable.pool_id == pool_id) + .where(TourneyPoolMapsTable.mods == mods) + .where(TourneyPoolMapsTable.slot == slot) + ) + tourney_pool_map = await app.state.services.database.fetch_one(select_stmt) + return cast(TourneyPoolMap | None, tourney_pool_map) async def delete_map_from_pool(pool_id: int, map_id: int) -> TourneyPoolMap | None: """Delete a map pool entry from a given tourney pool from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pool_maps - WHERE pool_id = :pool_id - AND map_id = :map_id - """ - params: dict[str, Any] = { - "pool_id": pool_id, - "map_id": map_id, - } - tourney_pool_map = await app.state.services.database.fetch_one(query, params) + select_stmt = ( + select(*READ_PARAMS) + .where(TourneyPoolMapsTable.pool_id == pool_id) + .where(TourneyPoolMapsTable.map_id == map_id) + ) + + tourney_pool_map = await app.state.services.database.fetch_one(select_stmt) if tourney_pool_map is None: return None - query = f"""\ - DELETE FROM tourney_pool_maps - WHERE pool_id = :pool_id - AND map_id = :map_id - """ - params = { - "pool_id": pool_id, - "map_id": map_id, - } - await app.state.services.database.execute(query, params) - return cast(TourneyPoolMap, dict(tourney_pool_map._mapping)) + delete_stmt = ( + delete(TourneyPoolMapsTable) + .where(TourneyPoolMapsTable.pool_id == pool_id) + .where(TourneyPoolMapsTable.map_id == map_id) + ) + + await 
app.state.services.database.execute(delete_stmt) + return cast(TourneyPoolMap, tourney_pool_map) async def delete_all_in_pool(pool_id: int) -> list[TourneyPoolMap]: """Delete all map pool entries from a given tourney pool from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pool_maps - WHERE pool_id = :pool_id - """ - params: dict[str, Any] = { - "pool_id": pool_id, - } - tourney_pool_maps = await app.state.services.database.fetch_all(query, params) + select_stmt = select(*READ_PARAMS).where(TourneyPoolMapsTable.pool_id == pool_id) + tourney_pool_maps = await app.state.services.database.fetch_all(select_stmt) if not tourney_pool_maps: return [] - query = f"""\ - DELETE FROM tourney_pool_maps - WHERE pool_id = :pool_id - """ - params = { - "pool_id": pool_id, - } - await app.state.services.database.execute(query, params) - return cast( - list[TourneyPoolMap], - [dict(tourney_pool_map._mapping) for tourney_pool_map in tourney_pool_maps], + delete_stmt = delete(TourneyPoolMapsTable).where( + TourneyPoolMapsTable.pool_id == pool_id, ) + await app.state.services.database.execute(delete_stmt) + return cast(list[TourneyPoolMap], tourney_pool_maps) diff --git a/app/repositories/tourney_pools.py b/app/repositories/tourney_pools.py index dade1a2ed..af1112773 100644 --- a/app/repositories/tourney_pools.py +++ b/app/repositories/tourney_pools.py @@ -1,21 +1,32 @@ from __future__ import annotations -import textwrap from datetime import datetime -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import delete +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select + import app.state.services +from app.repositories import Base + + +class TourneyPoolsTable(Base): + __tablename__ = "tourney_pools" + + id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True) + name = Column("name", String(16), nullable=False) + created_at = Column("created_at", DateTime, nullable=False) + created_by = Column("created_by", Integer, nullable=False) -# +------------+-------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +------------+-------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | name | varchar(16) | NO | | NULL | | -# | created_at | datetime | NO | | NULL | | -# | created_by | int | NO | MUL | NULL | | -# +------------+-------------+------+-----+---------+----------------+ + __table_args__ = (Index("tourney_pools_users_id_fk", created_by),) class TourneyPool(TypedDict): @@ -25,37 +36,27 @@ class TourneyPool(TypedDict): created_by: int -READ_PARAMS = textwrap.dedent( - """\ - id, name, created_at, created_by - """, +READ_PARAMS = ( + TourneyPoolsTable.id, + TourneyPoolsTable.name, + TourneyPoolsTable.created_at, + TourneyPoolsTable.created_by, ) async def create(name: str, created_by: int) -> TourneyPool: """Create a new tourney pool entry in the database.""" - query = f"""\ - INSERT INTO tourney_pools (name, created_at, created_by) - VALUES (:name, NOW(), :user_id) - """ - params: dict[str, Any] = { - "name": name, - "user_id": created_by, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pools - WHERE id = :id - """ - params = { - "id": rec_id, - } - 
tourney_pool = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(TourneyPoolsTable).values( + name=name, + created_at=func.now(), + created_by=created_by, + ) + rec_id = await app.state.services.database.execute(insert_stmt) + select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == rec_id) + tourney_pool = await app.state.services.database.fetch_one(select_stmt) assert tourney_pool is not None - return cast(TourneyPool, dict(tourney_pool._mapping)) + return cast(TourneyPool, tourney_pool) async def fetch_many( @@ -64,85 +65,40 @@ async def fetch_many( page: int | None = 1, page_size: int | None = 50, ) -> list[TourneyPool]: - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pools - WHERE id = COALESCE(:id, id) - AND created_by = COALESCE(:created_by, created_by) - """ - params: dict[str, Any] = { - "id": id, - "created_by": created_by, - } + """Fetch many tourney pools from the database.""" + select_stmt = select(*READ_PARAMS) + if id is not None: + select_stmt = select_stmt.where(TourneyPoolsTable.id == id) + if created_by is not None: + select_stmt = select_stmt.where(TourneyPoolsTable.created_by == created_by) if page and page_size: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size - tourney_pools = await app.state.services.database.fetch_all(query, params) - return [ - cast(TourneyPool, dict(tourney_pool._mapping)) for tourney_pool in tourney_pools - ] + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + + tourney_pools = await app.state.services.database.fetch_all(select_stmt) + return cast(list[TourneyPool], tourney_pools) async def fetch_by_name(name: str) -> TourneyPool | None: """Fetch a tourney pool by name from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pools - WHERE name = :name - """ - params: dict[str, Any] = { - "name": name, - } - tourney_pool = await app.state.services.database.fetch_one(query, params) - return ( - cast(TourneyPool, dict(tourney_pool._mapping)) - if tourney_pool is not None - else None - ) + select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.name == name) + tourney_pool = await app.state.services.database.fetch_one(select_stmt) + return cast(TourneyPool | None, tourney_pool) async def fetch_by_id(id: int) -> TourneyPool | None: """Fetch a tourney pool by id from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pools - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - tourney_pool = await app.state.services.database.fetch_one(query, params) - return ( - cast(TourneyPool, dict(tourney_pool._mapping)) - if tourney_pool is not None - else None - ) + select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == id) + tourney_pool = await app.state.services.database.fetch_one(select_stmt) + return cast(TourneyPool | None, tourney_pool) async def delete_by_id(id: int) -> TourneyPool | None: """Delete a tourney pool by id from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM tourney_pools - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } - tourney_pool = await app.state.services.database.fetch_one(query, params) + select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == id) + tourney_pool = await app.state.services.database.fetch_one(select_stmt) if tourney_pool is None: return None - query = f"""\ - DELETE FROM tourney_pools - WHERE id = :id - """ - params = { - "id": id, - } - await 
app.state.services.database.execute(query, params) - return cast(TourneyPool, dict(tourney_pool._mapping)) + delete_stmt = delete(TourneyPoolsTable).where(TourneyPoolsTable.id == id) + await app.state.services.database.execute(delete_stmt) + return cast(TourneyPool, tourney_pool) diff --git a/app/repositories/user_achievements.py b/app/repositories/user_achievements.py index 75ffb6c25..1c3065a06 100644 --- a/app/repositories/user_achievements.py +++ b/app/repositories/user_achievements.py @@ -1,25 +1,35 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import insert +from sqlalchemy import select + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import Base + + +class UserAchievementsTable(Base): + __tablename__ = "user_achievements" + + userid = Column("userid", Integer, nullable=False, primary_key=True) + achid = Column("achid", Integer, nullable=False, primary_key=True) + + __table_args__ = ( + Index("user_achievements_achid_index", achid), + Index("user_achievements_userid_index", userid), + ) -# create table user_achievements -# ( -# userid int not null, -# achid int not null, -# primary key (userid, achid) -# ); - -READ_PARAMS = textwrap.dedent( - """\ - userid, achid - """, + +READ_PARAMS = ( + UserAchievementsTable.userid, + UserAchievementsTable.achid, ) @@ -30,56 +40,40 @@ class UserAchievement(TypedDict): async def create(user_id: int, achievement_id: int) -> UserAchievement: """Creates a new user achievement entry.""" - query = """\ - INSERT INTO user_achievements (userid, achid) - VALUES (:user_id, :achievement_id) - """ - params: dict[str, Any] = { - "user_id": user_id, - "achievement_id": achievement_id, - } - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM user_achievements - WHERE userid = :user_id - AND achid = :achievement_id - """ - user_achievement = await app.state.services.database.fetch_one(query, params) - + insert_stmt = insert(UserAchievementsTable).values( + userid=user_id, + achid=achievement_id, + ) + await app.state.services.database.execute(insert_stmt) + + select_stmt = ( + select(*READ_PARAMS) + .where(UserAchievementsTable.userid == user_id) + .where(UserAchievementsTable.achid == achievement_id) + ) + user_achievement = await app.state.services.database.fetch_one(select_stmt) assert user_achievement is not None - return cast(UserAchievement, dict(user_achievement._mapping)) + return cast(UserAchievement, user_achievement) async def fetch_many( - user_id: int, - page: int | _UnsetSentinel = UNSET, - page_size: int | _UnsetSentinel = UNSET, + user_id: int | _UnsetSentinel = UNSET, + achievement_id: int | _UnsetSentinel = UNSET, + page: int | None = None, + page_size: int | None = None, ) -> list[UserAchievement]: """Fetch a list of user achievements.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM user_achievements - WHERE userid = :user_id - """ - params: dict[str, Any] = { - "user_id": user_id, - } - - if not isinstance(page, _UnsetSentinel) and not isinstance( - page_size, - _UnsetSentinel, - ): - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["page_size"] = page_size - params["offset"] = (page - 1) * page_size - - user_achievements = await app.state.services.database.fetch_all(query, params) - return 
cast(list[UserAchievement], [dict(a._mapping) for a in user_achievements]) + select_stmt = select(*READ_PARAMS) + if not isinstance(user_id, _UnsetSentinel): + select_stmt = select_stmt.where(UserAchievementsTable.userid == user_id) + if not isinstance(achievement_id, _UnsetSentinel): + select_stmt = select_stmt.where(UserAchievementsTable.achid == achievement_id) + + if page and page_size: + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) + + user_achievements = await app.state.services.database.fetch_all(select_stmt) + return cast(list[UserAchievement], user_achievements) # TODO: delete? diff --git a/app/repositories/users.py b/app/repositories/users.py index a634f8771..b94d860f0 100644 --- a/app/repositories/users.py +++ b/app/repositories/users.py @@ -1,45 +1,77 @@ from __future__ import annotations -import textwrap -from typing import Any from typing import TypedDict from typing import cast +from sqlalchemy import Column +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import func +from sqlalchemy import insert +from sqlalchemy import select +from sqlalchemy import update +from sqlalchemy.dialects.mysql import TINYINT + import app.state.services from app._typing import UNSET from app._typing import _UnsetSentinel +from app.repositories import Base from app.utils import make_safe_name -# +-------------------+---------------+------+-----+---------+----------------+ -# | Field | Type | Null | Key | Default | Extra | -# +-------------------+---------------+------+-----+---------+----------------+ -# | id | int | NO | PRI | NULL | auto_increment | -# | name | varchar(32) | NO | UNI | NULL | | -# | safe_name | varchar(32) | NO | UNI | NULL | | -# | email | varchar(254) | NO | UNI | NULL | | -# | priv | int | NO | | 1 | | -# | pw_bcrypt | char(60) | NO | | NULL | | -# | country | char(2) | NO | | xx | | -# | silence_end | int | NO | | 0 | | -# | donor_end | int | NO | | 0 | | -# | creation_time | int | NO | | 0 | | -# | latest_activity | int | NO | | 0 | | -# | clan_id | int | NO | | 0 | | -# | clan_priv | tinyint(1) | NO | | 0 | | -# | preferred_mode | int | NO | | 0 | | -# | play_style | int | NO | | 0 | | -# | custom_badge_name | varchar(16) | YES | | NULL | | -# | custom_badge_icon | varchar(64) | YES | | NULL | | -# | userpage_content | varchar(2048) | YES | | NULL | | -# | api_key | char(36) | YES | UNI | NULL | | -# +-------------------+---------------+------+-----+---------+----------------+ - -READ_PARAMS = textwrap.dedent( - """\ - id, name, safe_name, priv, country, silence_end, donor_end, creation_time, - latest_activity, clan_id, clan_priv, preferred_mode, play_style, custom_badge_name, - custom_badge_icon, userpage_content - """, + +class UsersTable(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, nullable=False, autoincrement=True) + name = Column(String(32, collation="utf8"), nullable=False) + safe_name = Column(String(32, collation="utf8"), nullable=False) + email = Column(String(254), nullable=False) + priv = Column(Integer, nullable=False, server_default="1") + pw_bcrypt = Column(String(60), nullable=False) + country = Column(String(2), nullable=False, server_default="xx") + silence_end = Column(Integer, nullable=False, server_default="0") + donor_end = Column(Integer, nullable=False, server_default="0") + creation_time = Column(Integer, nullable=False, server_default="0") + latest_activity = Column(Integer, nullable=False, server_default="0") + clan_id = 
Column(Integer, nullable=False, server_default="0") + clan_priv = Column(TINYINT, nullable=False, server_default="0") + preferred_mode = Column(Integer, nullable=False, server_default="0") + play_style = Column(Integer, nullable=False, server_default="0") + custom_badge_name = Column(String(16, collation="utf8")) + custom_badge_icon = Column(String(64)) + userpage_content = Column(String(2048, collation="utf8")) + api_key = Column(String(36)) + + __table_args__ = ( + Index("users_priv_index", priv), + Index("users_clan_id_index", clan_id), + Index("users_clan_priv_index", clan_priv), + Index("users_country_index", country), + Index("users_api_key_uindex", api_key, unique=True), + Index("users_email_uindex", email, unique=True), + Index("users_name_uindex", name, unique=True), + Index("users_safe_name_uindex", safe_name, unique=True), + ) + + +READ_PARAMS = ( + UsersTable.id, + UsersTable.name, + UsersTable.safe_name, + UsersTable.priv, + UsersTable.country, + UsersTable.silence_end, + UsersTable.donor_end, + UsersTable.creation_time, + UsersTable.latest_activity, + UsersTable.clan_id, + UsersTable.clan_priv, + UsersTable.preferred_mode, + UsersTable.play_style, + UsersTable.custom_badge_name, + UsersTable.custom_badge_icon, + UsersTable.userpage_content, ) @@ -64,26 +96,6 @@ class User(TypedDict): api_key: str | None -class UserUpdateFields(TypedDict, total=False): - name: str - safe_name: str - email: str - priv: int - country: str - silence_end: int - donor_end: int - creation_time: int - latest_activity: int - clan_id: int - clan_priv: int - preferred_mode: int - play_style: int - custom_badge_name: str | None - custom_badge_icon: str | None - userpage_content: str | None - api_key: str | None - - async def create( name: str, email: str, @@ -91,31 +103,21 @@ async def create( country: str, ) -> User: """Create a new user in the database.""" - query = f"""\ - INSERT INTO users (name, safe_name, email, pw_bcrypt, country, creation_time, latest_activity) - VALUES (:name, :safe_name, :email, :pw_bcrypt, :country, UNIX_TIMESTAMP(), UNIX_TIMESTAMP()) - """ - params: dict[str, Any] = { - "name": name, - "safe_name": make_safe_name(name), - "email": email, - "pw_bcrypt": pw_bcrypt, - "country": country, - } - rec_id = await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM users - WHERE id = :id - """ - params = { - "id": rec_id, - } - user = await app.state.services.database.fetch_one(query, params) + insert_stmt = insert(UsersTable).values( + name=name, + safe_name=make_safe_name(name), + email=email, + pw_bcrypt=pw_bcrypt, + country=country, + creation_time=func.unix_timestamp(), + latest_activity=func.unix_timestamp(), + ) + rec_id = await app.state.services.database.execute(insert_stmt) + select_stmt = select(*READ_PARAMS).where(UsersTable.id == rec_id) + user = await app.state.services.database.fetch_one(select_stmt) assert user is not None - return cast(User, dict(user._mapping)) + return cast(User, user) async def fetch_one( @@ -128,20 +130,20 @@ async def fetch_one( if id is None and name is None and email is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {'*' if fetch_all_fields else READ_PARAMS} - FROM users - WHERE id = COALESCE(:id, id) - AND safe_name = COALESCE(:safe_name, safe_name) - AND email = COALESCE(:email, email) - """ - params: dict[str, Any] = { - "id": id, - "safe_name": make_safe_name(name) if name is not None else None, - "email": email, - } - user = await 
app.state.services.database.fetch_one(query, params) - return cast(User, dict(user._mapping)) if user is not None else None + if fetch_all_fields: + select_stmt = select(UsersTable) + else: + select_stmt = select(*READ_PARAMS) + + if id is not None: + select_stmt = select_stmt.where(UsersTable.id == id) + if name is not None: + select_stmt = select_stmt.where(UsersTable.safe_name == make_safe_name(name)) + if email is not None: + select_stmt = select_stmt.where(UsersTable.email == email) + + user = await app.state.services.database.fetch_one(select_stmt) + return cast(User | None, user) async def fetch_count( @@ -153,27 +155,23 @@ async def fetch_count( play_style: int | None = None, ) -> int: """Fetch the number of users in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM users - WHERE priv = COALESCE(:priv, priv) - AND country = COALESCE(:country, country) - AND clan_id = COALESCE(:clan_id, clan_id) - AND clan_priv = COALESCE(:clan_priv, clan_priv) - AND preferred_mode = COALESCE(:preferred_mode, preferred_mode) - AND play_style = COALESCE(:play_style, play_style) - """ - params: dict[str, Any] = { - "priv": priv, - "country": country, - "clan_id": clan_id, - "clan_priv": clan_priv, - "preferred_mode": preferred_mode, - "play_style": play_style, - } - rec = await app.state.services.database.fetch_one(query, params) + select_stmt = select(func.count().label("count")) + if priv is not None: + select_stmt = select_stmt.where(UsersTable.priv == priv) + if country is not None: + select_stmt = select_stmt.where(UsersTable.country == country) + if clan_id is not None: + select_stmt = select_stmt.where(UsersTable.clan_id == clan_id) + if clan_priv is not None: + select_stmt = select_stmt.where(UsersTable.clan_priv == clan_priv) + if preferred_mode is not None: + select_stmt = select_stmt.where(UsersTable.preferred_mode == preferred_mode) + if play_style is not None: + select_stmt = select_stmt.where(UsersTable.play_style == play_style) + + rec = await app.state.services.database.fetch_one(select_stmt) assert rec is not None - return cast(int, rec._mapping["count"]) + return cast(int, rec["count"]) async def fetch_many( @@ -187,38 +185,28 @@ async def fetch_many( page_size: int | None = None, ) -> list[User]: """Fetch multiple users from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM users - WHERE priv = COALESCE(:priv, priv) - AND country = COALESCE(:country, country) - AND clan_id = COALESCE(:clan_id, clan_id) - AND clan_priv = COALESCE(:clan_priv, clan_priv) - AND preferred_mode = COALESCE(:preferred_mode, preferred_mode) - AND play_style = COALESCE(:play_style, play_style) - """ - params: dict[str, Any] = { - "priv": priv, - "country": country, - "clan_id": clan_id, - "clan_priv": clan_priv, - "preferred_mode": preferred_mode, - "play_style": play_style, - } + select_stmt = select(*READ_PARAMS) + if priv is not None: + select_stmt = select_stmt.where(UsersTable.priv == priv) + if country is not None: + select_stmt = select_stmt.where(UsersTable.country == country) + if clan_id is not None: + select_stmt = select_stmt.where(UsersTable.clan_id == clan_id) + if clan_priv is not None: + select_stmt = select_stmt.where(UsersTable.clan_priv == clan_priv) + if preferred_mode is not None: + select_stmt = select_stmt.where(UsersTable.preferred_mode == preferred_mode) + if play_style is not None: + select_stmt = select_stmt.where(UsersTable.play_style == play_style) if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - 
params["limit"] = page_size - params["offset"] = (page - 1) * page_size + select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size) - users = await app.state.services.database.fetch_all(query, params) - return cast(list[User], [dict(p._mapping) for p in users]) + users = await app.state.services.database.fetch_all(select_stmt) + return cast(list[User], users) -async def update( +async def partial_update( id: int, name: str | _UnsetSentinel = UNSET, email: str | _UnsetSentinel = UNSET, @@ -238,61 +226,45 @@ async def update( api_key: str | None | _UnsetSentinel = UNSET, ) -> User | None: """Update a user in the database.""" - update_fields: UserUpdateFields = {} + update_stmt = update(UsersTable).where(UsersTable.id == id) if not isinstance(name, _UnsetSentinel): - update_fields["name"] = name - update_fields["safe_name"] = make_safe_name(name) + update_stmt = update_stmt.values(name=name, safe_name=make_safe_name(name)) if not isinstance(email, _UnsetSentinel): - update_fields["email"] = email + update_stmt = update_stmt.values(email=email) if not isinstance(priv, _UnsetSentinel): - update_fields["priv"] = priv + update_stmt = update_stmt.values(priv=priv) if not isinstance(country, _UnsetSentinel): - update_fields["country"] = country + update_stmt = update_stmt.values(country=country) if not isinstance(silence_end, _UnsetSentinel): - update_fields["silence_end"] = silence_end + update_stmt = update_stmt.values(silence_end=silence_end) if not isinstance(donor_end, _UnsetSentinel): - update_fields["donor_end"] = donor_end + update_stmt = update_stmt.values(donor_end=donor_end) if not isinstance(creation_time, _UnsetSentinel): - update_fields["creation_time"] = creation_time + update_stmt = update_stmt.values(creation_time=creation_time) if not isinstance(latest_activity, _UnsetSentinel): - update_fields["latest_activity"] = latest_activity + update_stmt = update_stmt.values(latest_activity=latest_activity) if not isinstance(clan_id, _UnsetSentinel): - update_fields["clan_id"] = clan_id + update_stmt = update_stmt.values(clan_id=clan_id) if not isinstance(clan_priv, _UnsetSentinel): - update_fields["clan_priv"] = clan_priv + update_stmt = update_stmt.values(clan_priv=clan_priv) if not isinstance(preferred_mode, _UnsetSentinel): - update_fields["preferred_mode"] = preferred_mode + update_stmt = update_stmt.values(preferred_mode=preferred_mode) if not isinstance(play_style, _UnsetSentinel): - update_fields["play_style"] = play_style + update_stmt = update_stmt.values(play_style=play_style) if not isinstance(custom_badge_name, _UnsetSentinel): - update_fields["custom_badge_name"] = custom_badge_name + update_stmt = update_stmt.values(custom_badge_name=custom_badge_name) if not isinstance(custom_badge_icon, _UnsetSentinel): - update_fields["custom_badge_icon"] = custom_badge_icon + update_stmt = update_stmt.values(custom_badge_icon=custom_badge_icon) if not isinstance(userpage_content, _UnsetSentinel): - update_fields["userpage_content"] = userpage_content + update_stmt = update_stmt.values(userpage_content=userpage_content) if not isinstance(api_key, _UnsetSentinel): - update_fields["api_key"] = api_key - - query = f"""\ - UPDATE users - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ - params: dict[str, Any] = { - "id": id, - } | update_fields - await app.state.services.database.execute(query, params) - - query = f"""\ - SELECT {READ_PARAMS} - FROM users - WHERE id = :id - """ - params = { - "id": id, - } - user = await 
app.state.services.database.fetch_one(query, params) - return cast(User, dict(user._mapping)) if user is not None else None + update_stmt = update_stmt.values(api_key=api_key) + + await app.state.services.database.execute(update_stmt) + + select_stmt = select(*READ_PARAMS).where(UsersTable.id == id) + user = await app.state.services.database.fetch_one(select_stmt) + return cast(User | None, user) # TODO: delete? diff --git a/app/state/services.py b/app/state/services.py index d206a9aa7..8198ceb6e 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING from typing import TypedDict -import databases import datadog as datadog_module import datadog.threadstats.base as datadog_client import httpx @@ -21,15 +20,12 @@ import app.settings import app.state from app._typing import IPAddress +from app.adapters.database import Database from app.logging import Ansi from app.logging import Rainbow from app.logging import log from app.logging import printc -if TYPE_CHECKING: - import databases.core - - STRANGE_LOG_DIR = Path.cwd() / ".data/logs" VERSION_RGX = re.compile(r"^# v(?P\d+\.\d+\.\d+)$") @@ -39,7 +35,7 @@ """ session objects """ http_client = httpx.AsyncClient() -database = databases.Database(app.settings.DB_DSN) +database = Database(app.settings.DB_DSN) redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN) datadog: datadog_client.ThreadStats | None = None @@ -394,7 +390,7 @@ async def _get_current_sql_structure_version() -> Version | None: ) if res: - return Version(*map(int, res)) + return Version(res["ver_major"], res["ver_minor"], res["ver_micro"]) return None diff --git a/app/usecases/user_achievements.py b/app/usecases/user_achievements.py index cde73895e..256673ce3 100644 --- a/app/usecases/user_achievements.py +++ b/app/usecases/user_achievements.py @@ -15,13 +15,13 @@ async def create(user_id: int, achievement_id: int) -> UserAchievement: async def fetch_many( - user_id: int, - page: int | _UnsetSentinel = UNSET, - page_size: int | _UnsetSentinel = UNSET, + user_id: int | _UnsetSentinel = UNSET, + page: int | None = None, + page_size: int | None = None, ) -> list[UserAchievement]: user_achievements = await app.repositories.user_achievements.fetch_many( - user_id, - page, - page_size, + user_id=user_id, + page=page, + page_size=page_size, ) return user_achievements diff --git a/pyproject.toml b/pyproject.toml index 7dc5ca374..400f38bc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ profile = "black" [tool.poetry] name = "bancho-py" -version = "5.0.1" +version = "5.1.0" description = "An osu! 
server implementation optimized for maintainability in modern python" authors = ["Akatsuki Team"] license = "MIT" From d570e77f9c6f9f68d12433f15235e3dd7a029a5c Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Mon, 26 Feb 2024 01:42:20 -0500 Subject: [PATCH 39/48] Low-lift solution to move to stdlib logging (#611) * low-lift solution move to stdlib * reverts * more directly match old behaviour * change err to warn * one msg log for root warning * remove unnecessary logging code * green log for all non-err handled reqs * logout is not a warning * Support stdlib logging `extra` parameter data forwarding * spacing * accurate comment * remove legacy chatlog logging to file * simplify * file logging without colors (for the sanity of humanity) * move debug query logging * sql query useful logs in stdout * backwards compat for .data/logs/chat.log * le fix * prefer const * Use `logging.yaml.example` * update docs for logging.yaml.example * Bump minor version to 5.2.0 --- .env.example | 2 + .github/docs/wiki/Setting-up.md | 10 +- .../wiki/locale/de-DE/Setting-up-de-DE.md | 10 +- .../wiki/locale/zh-CN/Setting-up-zh-CN.md | 10 +- .github/workflows/test.yaml | 1 + .gitignore | 2 + app/adapters/database.py | 9 +- app/api/domains/cho.py | 19 ++- app/api/domains/osu.py | 1 - app/api/init_api.py | 2 +- app/api/middlewares.py | 13 +- app/logging.py | 116 ++++++------------ app/objects/player.py | 2 +- app/settings.py | 2 + app/state/services.py | 20 +-- app/utils.py | 16 ++- docker-compose.test.yml | 1 + docker-compose.yml | 1 + logging.yaml.example | 32 +++++ main.py | 3 + poetry.lock | 13 +- pyproject.toml | 3 +- tools/proxy.py | 12 +- 23 files changed, 174 insertions(+), 126 deletions(-) create mode 100644 logging.yaml.example diff --git a/.env.example b/.env.example index 4e43ce48e..85846fc03 100644 --- a/.env.example +++ b/.env.example @@ -63,6 +63,8 @@ DISCORD_AUDIT_LOG_WEBHOOK= # for debugging & development purposes. AUTOMATICALLY_REPORT_PROBLEMS=False +LOG_WITH_COLORS=False + # XXX: Uncomment this if you have downloaded the database from maxmind. # Change the path to the .mmdb file you downloaded, uncomment here and in docker-compose.yml # You can download the database here: https://dev.maxmind.com/geoip/geolite2-free-geolocation-data diff --git a/.github/docs/wiki/Setting-up.md b/.github/docs/wiki/Setting-up.md index 0db2eb829..85517d93a 100644 --- a/.github/docs/wiki/Setting-up.md +++ b/.github/docs/wiki/Setting-up.md @@ -17,15 +17,23 @@ sudo apt install -y docker docker-compose ## configuring bancho.py all configuration for the osu! server (bancho.py) itself can be done from the -`.env` file. we provide an example `.env.example` file which you can use as a base. +`.env` and `logging.yaml` files. we will provide example files for each, which +you can use as a base and modify as you'd like. ```sh # create a configuration file from the sample provided cp .env.example .env +# create a logging configuration file from the sample provided +cp logging.yaml.example logging.yaml + # configure the application to your needs # this is required to move onto the next steps nano .env + +# you can additionally configure the logging if you'd like, +# but the default should work fine for most users. 
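+# the example config logs to the console; it also ships a commented-out
+# JSON file handler you can enable if you want logs written to disk.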
+nano logging.yaml ``` ## configuring a reverse proxy (we'll use nginx) diff --git a/.github/docs/wiki/locale/de-DE/Setting-up-de-DE.md b/.github/docs/wiki/locale/de-DE/Setting-up-de-DE.md index 0db2eb829..85517d93a 100644 --- a/.github/docs/wiki/locale/de-DE/Setting-up-de-DE.md +++ b/.github/docs/wiki/locale/de-DE/Setting-up-de-DE.md @@ -17,15 +17,23 @@ sudo apt install -y docker docker-compose ## configuring bancho.py all configuration for the osu! server (bancho.py) itself can be done from the -`.env` file. we provide an example `.env.example` file which you can use as a base. +`.env` and `logging.yaml` files. we will provide example files for each, which +you can use as a base and modify as you'd like. ```sh # create a configuration file from the sample provided cp .env.example .env +# create a logging configuration file from the sample provided +cp logging.yaml.example logging.yaml + # configure the application to your needs # this is required to move onto the next steps nano .env + +# you can additionally configure the logging if you'd like, +# but the default should work fine for most users. +nano logging.yaml ``` ## configuring a reverse proxy (we'll use nginx) diff --git a/.github/docs/wiki/locale/zh-CN/Setting-up-zh-CN.md b/.github/docs/wiki/locale/zh-CN/Setting-up-zh-CN.md index 0db2eb829..85517d93a 100644 --- a/.github/docs/wiki/locale/zh-CN/Setting-up-zh-CN.md +++ b/.github/docs/wiki/locale/zh-CN/Setting-up-zh-CN.md @@ -17,15 +17,23 @@ sudo apt install -y docker docker-compose ## configuring bancho.py all configuration for the osu! server (bancho.py) itself can be done from the -`.env` file. we provide an example `.env.example` file which you can use as a base. +`.env` and `logging.yaml` files. we will provide example files for each, which +you can use as a base and modify as you'd like. ```sh # create a configuration file from the sample provided cp .env.example .env +# create a logging configuration file from the sample provided +cp logging.yaml.example logging.yaml + # configure the application to your needs # this is required to move onto the next steps nano .env + +# you can additionally configure the logging if you'd like, +# but the default should work fine for most users. +nano logging.yaml ``` ## configuring a reverse proxy (we'll use nginx) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b05341a79..4f72fe45b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,6 +18,7 @@ env: APP_HOST: "0.0.0.0" APP_PORT: "10000" AUTOMATICALLY_REPORT_PROBLEMS: "False" + LOG_WITH_COLORS: "False" COMMAND_PREFIX: "!" 
DATA_DIRECTORY: "not relevant" DB_HOST: "mysql" diff --git a/.gitignore b/.gitignore index 33b467341..7cd198bfb 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ tools/cf_records.txt /.redis-data/ poetry.toml .coverage +logging.yaml +logs.log diff --git a/app/adapters/database.py b/app/adapters/database.py index 117469a90..c5e40b437 100644 --- a/app/adapters/database.py +++ b/app/adapters/database.py @@ -10,6 +10,7 @@ from sqlalchemy.sql.expression import ClauseElement from app import settings +from app.logging import log class MySQLDialect(MySQLDialect_mysqldb): @@ -38,8 +39,6 @@ def _compile(self, clause_element: ClauseElement) -> tuple[str, MySQLParams]: dialect=DIALECT, compile_kwargs={"render_postcompile": True}, ) - if settings.DEBUG: - print(str(compiled), compiled.params) return str(compiled), compiled.params async def fetch_one( @@ -51,6 +50,12 @@ async def fetch_one( query, params = self._compile(query) row = await self._database.fetch_one(query, params) + if settings.DEBUG: + log( + f"Executed SQL query: {query} {params}", + extra={"query": query, "params": params}, + ) + return dict(row._mapping) if row is not None else None async def fetch_all( diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 14ca8b36f..068277f75 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import logging import re import struct import time @@ -13,6 +14,7 @@ from pathlib import Path from typing import Literal from typing import TypedDict +from zoneinfo import ZoneInfo import bcrypt import databases.core @@ -37,6 +39,7 @@ from app.constants.privileges import ClientPrivileges from app.constants.privileges import Privileges from app.logging import Ansi +from app.logging import get_timestamp from app.logging import log from app.logging import magnitude_fmt_time from app.objects.beatmap import Beatmap @@ -69,6 +72,7 @@ OSU_API_V2_CHANGELOG_URL = "https://osu.ppy.sh/api/v2/changelog" BEATMAPS_PATH = Path.cwd() / ".data/osu" +DISK_CHAT_LOG_FILE = ".data/logs/chat.log" BASE_DOMAIN = app.settings.DOMAIN @@ -424,7 +428,13 @@ async def handle(self, player: Player) -> None: t_chan.send(msg, sender=player) player.update_latest_activity_soon() - log(f"{player} @ {t_chan}: {msg}", Ansi.LCYAN, file=".data/logs/chat.log") + + log(f"{player} @ {t_chan}: {msg}", Ansi.LCYAN) + + with open(DISK_CHAT_LOG_FILE, "a+") as f: + f.write( + f"[{get_timestamp(full=True, tz=ZoneInfo('GMT'))}] {player} @ {t_chan}: {msg}\n", + ) @register(ClientPackets.LOGOUT, restricted=True) @@ -1298,7 +1308,12 @@ async def handle(self, player: Player) -> None: player.send(resp_msg, sender=target) player.update_latest_activity_soon() - log(f"{player} @ {target}: {msg}", Ansi.LCYAN, file=".data/logs/chat.log") + + log(f"{player} @ {target}: {msg}", Ansi.LCYAN) + with open(DISK_CHAT_LOG_FILE, "a+") as f: + f.write( + f"[{get_timestamp(full=True, tz=ZoneInfo('GMT'))}] {player} @ {target}: {msg}\n", + ) @register(ClientPackets.PART_LOBBY) diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index 0122c5ec0..a7b79a6d7 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -51,7 +51,6 @@ from app.constants.privileges import Privileges from app.logging import Ansi from app.logging import log -from app.logging import printc from app.objects import models from app.objects.beatmap import Beatmap from app.objects.beatmap import RankedStatus diff --git a/app/api/init_api.py b/app/api/init_api.py index 1602c558d..74934d3f3 100644 --- 
a/app/api/init_api.py +++ b/app/api/init_api.py @@ -79,7 +79,7 @@ async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[None]: if app.utils.is_running_as_admin(): log( "Running the server with root privileges is not recommended.", - Ansi.LRED, + Ansi.LYELLOW, ) await app.state.services.database.connect() diff --git a/app/api/middlewares.py b/app/api/middlewares.py index 1a64dde5f..8ff917a03 100644 --- a/app/api/middlewares.py +++ b/app/api/middlewares.py @@ -10,7 +10,6 @@ from app.logging import Ansi from app.logging import log from app.logging import magnitude_fmt_time -from app.logging import printc class MetricsMiddleware(BaseHTTPMiddleware): @@ -25,16 +24,14 @@ async def dispatch( time_elapsed = end_time - start_time - col = ( - Ansi.LGREEN - if 200 <= response.status_code < 300 - else Ansi.LYELLOW if 300 <= response.status_code < 400 else Ansi.LRED - ) + col = Ansi.LGREEN if response.status_code < 400 else Ansi.LRED url = f"{request.headers['host']}{request['path']}" - log(f"[{request.method}] {response.status_code} {url}", col, end=" | ") - printc(f"Request took: {magnitude_fmt_time(time_elapsed)}", Ansi.LBLUE) + log( + f"[{request.method}] {response.status_code} {url}{Ansi.RESET!r} | {Ansi.LBLUE!r}Request took: {magnitude_fmt_time(time_elapsed)}", + col, + ) response.headers["process-time"] = str(round(time_elapsed) / 1e6) return response diff --git a/app/logging.py b/app/logging.py index c6b98a6aa..bffed18fb 100644 --- a/app/logging.py +++ b/app/logging.py @@ -1,10 +1,22 @@ from __future__ import annotations -import colorsys import datetime +import logging.config +import re +from collections.abc import Mapping from enum import IntEnum from zoneinfo import ZoneInfo +import yaml + +from app import settings + + +def configure_logging() -> None: + with open("logging.yaml") as f: + config = yaml.safe_load(f.read()) + logging.config.dictConfig(config) + class Ansi(IntEnum): # Default colours @@ -33,105 +45,49 @@ def __repr__(self) -> str: return f"\x1b[{self.value}m" -class RGB: - def __init__(self, *args: int) -> None: - largs = len(args) - - if largs == 3: - # r, g, b passed. - self.r, self.g, self.b = args - elif largs == 1: - # passed as single argument - rgb = args[0] - self.b = rgb & 0xFF - self.g = (rgb >> 8) & 0xFF - self.r = (rgb >> 16) & 0xFF - else: - raise ValueError("Incorrect params for RGB.") - - def __repr__(self) -> str: - return f"\x1b[38;2;{self.r};{self.g};{self.b}m" - - -class _Rainbow: ... - - -Rainbow = _Rainbow() - -Colour_Types = Ansi | RGB | _Rainbow - - def get_timestamp(full: bool = False, tz: ZoneInfo | None = None) -> str: fmt = "%d/%m/%Y %I:%M:%S%p" if full else "%I:%M:%S%p" return f"{datetime.datetime.now(tz=tz):{fmt}}" -# TODO: better solution than this; this at least requires the -# iana/tzinfo database to be installed, meaning it's limited. -_log_tz = ZoneInfo("GMT") # default +ANSI_ESCAPE_REGEX = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]") -def set_timezone(tz: ZoneInfo) -> None: - global _log_tz - _log_tz = tz +def escape_ansi(line: str) -> str: + return ANSI_ESCAPE_REGEX.sub("", line) -def printc(msg: str, col: Colour_Types, end: str = "\n") -> None: - """Print a string, in a specified ansi colour.""" - print(f"{col!r}{msg}{Ansi.RESET!r}", end=end) +ROOT_LOGGER = logging.getLogger() def log( msg: str, - col: Colour_Types | None = None, - file: str | None = None, - end: str = "\n", + start_color: Ansi | None = None, + extra: Mapping[str, object] | None = None, ) -> None: """\ - Print a string, in a specified ansi colour with timestamp. 
- - Allows for the functionality to write to a file as - well by passing the filepath with the `file` parameter. + A thin wrapper around the stdlib logging module to handle mostly + backwards-compatibility for colours during our migration to the + standard library logging module. """ - ts_short = get_timestamp(full=False, tz=_log_tz) - - if col: - if col is Rainbow: - print(f"{Ansi.GRAY!r}[{ts_short}] {_fmt_rainbow(msg, 2/3)}", end=end) - print(f"{Ansi.GRAY!r}[{ts_short}] {_fmt_rainbow(msg, 2/3)}", end=end) - else: - # normal colour - print(f"{Ansi.GRAY!r}[{ts_short}] {col!r}{msg}{Ansi.RESET!r}", end=end) + # TODO: decouple colors from the base logging function; move it to + # be a formatter-specific concern such that we can log without color. + if start_color is Ansi.LYELLOW: + log_level = logging.WARNING + elif start_color is Ansi.LRED: + log_level = logging.ERROR else: - print(f"{Ansi.GRAY!r}[{ts_short}]{Ansi.RESET!r} {msg}", end=end) - - if file: - # log simple ascii output to file. - with open(file, "a+") as f: - f.write(f"[{get_timestamp(full=True, tz=_log_tz)}] {msg}\n") - - -def rainbow_color_stops( - n: int = 10, - lum: float = 0.5, - end: float = 2 / 3, -) -> list[tuple[float, float, float]]: - return [ - (r * 255, g * 255, b * 255) - for r, g, b in [ - colorsys.hls_to_rgb(end * i / (n - 1), lum, 1) for i in range(n) - ] - ] - - -def _fmt_rainbow(msg: str, end: float = 2 / 3) -> str: - cols = [RGB(*map(int, rgb)) for rgb in rainbow_color_stops(n=len(msg), end=end)] - return "".join([f"{cols[i]!r}{c}" for i, c in enumerate(msg)]) + repr(Ansi.RESET) + log_level = logging.INFO + if settings.LOG_WITH_COLORS: + color_prefix = f"{start_color!r}" if start_color is not None else "" + color_suffix = f"{Ansi.RESET!r}" if start_color is not None else "" + else: + msg = escape_ansi(msg) + color_prefix = color_suffix = "" -def print_rainbow(msg: str, rainbow_end: float = 2 / 3, end: str = "\n") -> None: - print(_fmt_rainbow(msg, rainbow_end), end=end) + ROOT_LOGGER.log(log_level, f"{color_prefix}{msg}{color_suffix}", extra=extra) TIME_ORDER_SUFFIXES = ["nsec", "μsec", "msec", "sec"] diff --git a/app/objects/player.py b/app/objects/player.py index 85e3193c3..ce4264493 100644 --- a/app/objects/player.py +++ b/app/objects/player.py @@ -403,7 +403,7 @@ def logout(self) -> None: app.state.sessions.players.enqueue(app.packets.logout(self.id)) - log(f"{self} logged out.", Ansi.LYELLOW) + log(f"{self} logged out.") async def update_privs(self, new: Privileges) -> None: """Update `self`'s privileges to `new`.""" diff --git a/app/settings.py b/app/settings.py index 0980c7842..aa292197c 100644 --- a/app/settings.py +++ b/app/settings.py @@ -59,6 +59,8 @@ AUTOMATICALLY_REPORT_PROBLEMS = read_bool(os.environ["AUTOMATICALLY_REPORT_PROBLEMS"]) +LOG_WITH_COLORS = read_bool(os.environ["LOG_WITH_COLORS"]) + # advanced dev settings ## WARNING touch this once you've diff --git a/app/state/services.py b/app/state/services.py index 8198ceb6e..22929257e 100644 --- a/app/state/services.py +++ b/app/state/services.py @@ -1,6 +1,7 @@ from __future__ import annotations import ipaddress +import logging import pickle import re import secrets @@ -22,9 +23,7 @@ from app._typing import IPAddress from app.adapters.database import Database from app.logging import Ansi -from app.logging import Rainbow from app.logging import log -from app.logging import printc STRANGE_LOG_DIR = Path.cwd() / ".data/logs" @@ -245,8 +244,11 @@ async def log_strange_occurrence(obj: object) -> None: ) if response.status_code == 200 and 
response.read() == b"ok": uploaded = True - log("Logged strange occurrence to cmyui's server.", Ansi.LBLUE) - log("Thank you for your participation! <3", Rainbow) + log( + "Logged strange occurrence to cmyui's server. " + "Thank you for your participation! <3", + Ansi.LBLUE, + ) else: log( f"Autoupload to cmyui's server failed (HTTP {response.status_code})", @@ -263,11 +265,13 @@ async def log_strange_occurrence(obj: object) -> None: log_file.touch(exist_ok=False) log_file.write_bytes(pickled_obj) - log("Logged strange occurrence to", Ansi.LYELLOW, end=" ") - printc("/".join(log_file.parts[-4:]), Ansi.LBLUE) - log( - "Greatly appreciated if you could forward this to cmyui#0425 :)", + "Logged strange occurrence to" + "/".join(log_file.parts[-4:]), + Ansi.LYELLOW, + ) + log( + "It would be greatly appreciated if you could forward this to the " + "bancho.py development team. To do so, please email josh@akatsuki.gg", Ansi.LYELLOW, ) diff --git a/app/utils.py b/app/utils.py index 64f675c49..a294cb186 100644 --- a/app/utils.py +++ b/app/utils.py @@ -222,7 +222,7 @@ def is_running_as_admin() -> bool: def display_startup_dialog() -> None: """Print any general information or warnings to the console.""" if app.settings.DEVELOPER_MODE: - log("running in advanced mode", Ansi.LRED) + log("running in advanced mode", Ansi.LYELLOW) if app.settings.DEBUG: log("running in debug mode", Ansi.LMAGENTA) @@ -230,17 +230,15 @@ def display_startup_dialog() -> None: # unnecessary power over the operating system and is not advised. if is_running_as_admin(): log( - "It is not recommended to run bancho.py as root/admin, especially in production..", + "It is not recommended to run bancho.py as root/admin, especially in production." + + ( + " You are at increased risk as developer mode is enabled." 
+ if app.settings.DEVELOPER_MODE + else "" + ), Ansi.LYELLOW, ) - if app.settings.DEVELOPER_MODE: - log( - "The risk is even greater with features " - "such as config.advanced enabled.", - Ansi.LRED, - ) - if not has_internet_connectivity(): log("No internet connectivity detected", Ansi.LYELLOW) diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 1afd3eb54..6f3901287 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -87,6 +87,7 @@ services: - DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION} - DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK} - AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS} + - LOG_WITH_COLORS=${LOG_WITH_COLORS} - SSL_CERT_PATH=${SSL_CERT_PATH} - SSL_KEY_PATH=${SSL_KEY_PATH} - DEVELOPER_MODE=${DEVELOPER_MODE} diff --git a/docker-compose.yml b/docker-compose.yml index f9210f73d..91e71bace 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -83,6 +83,7 @@ services: - DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION} - DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK} - AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS} + - LOG_WITH_COLORS=${LOG_WITH_COLORS} - SSL_CERT_PATH=${SSL_CERT_PATH} - SSL_KEY_PATH=${SSL_KEY_PATH} - DEVELOPER_MODE=${DEVELOPER_MODE} diff --git a/logging.yaml.example b/logging.yaml.example new file mode 100644 index 000000000..a8f101d77 --- /dev/null +++ b/logging.yaml.example @@ -0,0 +1,32 @@ +version: 1 +disable_existing_loggers: true +loggers: + httpx: + level: WARNING + handlers: [console] + propagate: no + httpcore: + level: WARNING + handlers: [console] + propagate: no +handlers: + console: + class: logging.StreamHandler + level: INFO + formatter: plaintext + stream: ext://sys.stdout + # file: + # class: logging.FileHandler + # level: INFO + # formatter: json + # filename: logs.log +formatters: + plaintext: + format: '[%(asctime)s] %(levelname)s %(message)s' + datefmt: '%Y-%m-%d %H:%M:%S' + # json: + # class: pythonjsonlogger.jsonlogger.JsonFormatter + # format: '%(asctime)s %(name)s %(levelname)s %(message)s' +root: + level: INFO + handlers: [console] # , file] diff --git a/main.py b/main.py index 6fbaacf90..fae33964a 100755 --- a/main.py +++ b/main.py @@ -5,9 +5,12 @@ import uvicorn +import app.logging import app.settings import app.utils +app.logging.configure_logging() + def main() -> int: app.utils.display_startup_dialog() diff --git a/poetry.lock b/poetry.lock index ef594dcf9..c0f0085d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1441,6 +1441,17 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-json-logger" +version = "2.0.7" +description = "A python library adding a json log formatter" +optional = false +python-versions = ">=3.6" +files = [ + {file = "python-json-logger-2.0.7.tar.gz", hash = "sha256:23e7ec02d34237c5aa1e29a070193a4ea87583bb4e7f8fd06d3de8264c4b2e1c"}, + {file = "python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd"}, +] + [[package]] name = "python-multipart" version = "0.0.9" @@ -1929,4 +1940,4 @@ cython = "*" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b5d6c82b060f449f4b59e5b9e38e78871bbdf72b615f0a06cd6d13722e2871fe" +content-hash = "7f64adbe23aa171bfb5cb54c1019c1e6b3656a473ba1627965057e28f7d21372" diff --git a/pyproject.toml b/pyproject.toml index 400f38bc1..901ba987b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ profile = "black" [tool.poetry] name = "bancho-py" 
-version = "5.1.0" +version = "5.2.0" description = "An osu! server implementation optimized for maintainability in modern python" authors = ["Akatsuki Team"] license = "MIT" @@ -73,6 +73,7 @@ respx = "0.20.2" tzdata = "2024.1" coverage = "^7.4.1" databases = {version = "^0.8.0", extras = ["mysql"]} +python-json-logger = "^2.0.7" [tool.poetry.group.dev.dependencies] pre-commit = "3.6.1" diff --git a/tools/proxy.py b/tools/proxy.py index ea1297c30..099a9d37e 100644 --- a/tools/proxy.py +++ b/tools/proxy.py @@ -14,8 +14,6 @@ from mitmproxy import http -from app.logging import RGB - @unique class ServerPackets(IntEnum): @@ -83,19 +81,15 @@ def __repr__(self) -> str: return f"<{self.name} ({self.value})>" -BYTE_ORDER_SUFFIXES = [ - f"{RGB(0x76eb00)!r}B\x1b[0m", - f"{RGB(0xbfbf00)!r}KB\x1b[0m", - f"{RGB(0xe98b00)!r}MB\x1b[0m", - f"{RGB(0xfd4900)!r}GB\x1b[0m", -] +BYTE_ORDER_SUFFIXES = ["B", "KB", "MB", "GB"] def fmt_bytes(n: int | float) -> str: + suffix = None for suffix in BYTE_ORDER_SUFFIXES: if n < 1024: break - n /= 1024 # more to go + n /= 1024 return f"{n:,.2f}{suffix}" From cea03a74ec6b2b82435836a83155301aa42c925c Mon Sep 17 00:00:00 2001 From: cmyui Date: Mon, 26 Feb 2024 01:49:46 -0500 Subject: [PATCH 40/48] fix some bugs where a set without maps breaks the server --- app/objects/beatmap.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/app/objects/beatmap.py b/app/objects/beatmap.py index 6039f583b..894faba0d 100644 --- a/app/objects/beatmap.py +++ b/app/objects/beatmap.py @@ -625,6 +625,9 @@ def _cache_expired(self) -> bool: expired and needs an update from the osu!api.""" current_datetime = datetime.now() + if not self.maps: + return True + # the delta between cache invalidations will increase depending # on how long it's been since the map was last updated on osu! last_map_update = max(bmap.last_update for bmap in self.maps) @@ -758,19 +761,20 @@ async def _update_if_available(self) -> None: # TODO: a couple of open questions here: # - should we delete the beatmap from the database if it's not in the osu!api? # - are 404 and 200 the only cases where we should delete the beatmap? 
- map_md5s_to_delete = {bmap.md5 for bmap in self.maps} + if self.maps: + map_md5s_to_delete = {bmap.md5 for bmap in self.maps} - # delete maps - await app.state.services.database.execute( - "DELETE FROM maps WHERE md5 IN :map_md5s", - {"map_md5s": map_md5s_to_delete}, - ) + # delete maps + await app.state.services.database.execute( + "DELETE FROM maps WHERE md5 IN :map_md5s", + {"map_md5s": map_md5s_to_delete}, + ) - # delete scores on the maps - await app.state.services.database.execute( - "DELETE FROM scores WHERE map_md5 IN :map_md5s", - {"map_md5s": map_md5s_to_delete}, - ) + # delete scores on the maps + await app.state.services.database.execute( + "DELETE FROM scores WHERE map_md5 IN :map_md5s", + {"map_md5s": map_md5s_to_delete}, + ) # delete set await app.state.services.database.execute( From 6a8bf6386cd592d2d96b58545a330e2188deab24 Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:07:19 +0100 Subject: [PATCH 41/48] fix: ignore "0" disk-serial md5 on client hash checks (#604) * Add check for "0" md5-hash * small refactor/comment update --------- Co-authored-by: cmyui --- app/api/domains/cho.py | 14 +++++++++++++- app/repositories/client_hashes.py | 20 +++++++++++--------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 068277f75..b7ac5b8f1 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import hashlib import logging import re import struct @@ -759,12 +760,23 @@ async def handle_osu_login_request( # TODO: store adapters individually + # Some disk manufacturers set constant/shared ids for their products. + # In these cases, there's not a whole lot we can do -- we'll allow them thru. 
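+    # (the client reports its disk signature as an MD5 digest, so a constant
+    # "0" serial is expected to arrive as the md5 of the string "0")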
+ INACTIONABLE_DISK_SIGNATURE_MD5S: list[str] = [ + hashlib.md5(b"0").hexdigest(), # "0" is likely the most common variant + ] + + if login_data["disk_signature_md5"] not in INACTIONABLE_DISK_SIGNATURE_MD5S: + disk_signature_md5 = login_data["disk_signature_md5"] + else: + disk_signature_md5 = None + hw_matches = await client_hashes_repo.fetch_any_hardware_matches_for_user( userid=user_info["id"], running_under_wine=running_under_wine, adapters=login_data["adapters_md5"], uninstall_id=login_data["uninstall_md5"], - disk_serial=login_data["disk_signature_md5"], + disk_serial=disk_signature_md5, ) if hw_matches: diff --git a/app/repositories/client_hashes.py b/app/repositories/client_hashes.py index 67f3494bc..8f7d50c6c 100644 --- a/app/repositories/client_hashes.py +++ b/app/repositories/client_hashes.py @@ -13,6 +13,8 @@ from sqlalchemy import select from sqlalchemy.dialects.mysql import Insert as MysqlInsert from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.sql import ColumnElement +from sqlalchemy.types import Boolean import app.state.services from app.repositories import Base @@ -101,8 +103,8 @@ async def create( async def fetch_any_hardware_matches_for_user( userid: int, running_under_wine: bool, - adapters: str | None = None, - uninstall_id: str | None = None, + adapters: str, + uninstall_id: str, disk_serial: str | None = None, ) -> list[ClientHashWithPlayer]: """\ @@ -119,13 +121,13 @@ async def fetch_any_hardware_matches_for_user( if running_under_wine: select_stmt = select_stmt.where(ClientHashesTable.uninstall_id == uninstall_id) else: - select_stmt = select_stmt.where( - or_( - ClientHashesTable.adapters == adapters, - ClientHashesTable.uninstall_id == uninstall_id, - ClientHashesTable.disk_serial == disk_serial, - ), - ) + # make disk serial optional in the OR + oneof_filters: list[ColumnElement[Boolean]] = [] + oneof_filters.append(ClientHashesTable.adapters == adapters) + oneof_filters.append(ClientHashesTable.uninstall_id == uninstall_id) + if disk_serial is not None: + oneof_filters.append(ClientHashesTable.disk_serial == disk_serial) + select_stmt = select_stmt.where(or_(*oneof_filters)) client_hashes = await app.state.services.database.fetch_all(select_stmt) return cast(list[ClientHashWithPlayer], client_hashes) From f1a1eb2a7f8387583fc33328f11d74a99a97ec20 Mon Sep 17 00:00:00 2001 From: Khoo Hao Yit Date: Mon, 26 Feb 2024 17:07:15 +0800 Subject: [PATCH 42/48] fix: prevent crash when leaving tournament match channel (#479) * fix: remove non-existence player * small logic change to avoid unnecessary leave_channel() call --------- Co-authored-by: cmyui --- app/api/domains/cho.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index b7ac5b8f1..29790114e 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -1997,7 +1997,7 @@ async def handle(self, player: Player) -> None: return # insufficient privs match = app.state.sessions.matches[self.match_id] - if not match: + if not (match and player.id in match.tourney_clients): return # match not found # attempt to join match chan From bcd54a3af9b648ca584599173f5c7672c9918914 Mon Sep 17 00:00:00 2001 From: cmyui Date: Mon, 26 Feb 2024 04:11:57 -0500 Subject: [PATCH 43/48] Bump patch version to 5.2.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 901ba987b..a3869e5bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ profile = "black" 
[tool.poetry] name = "bancho-py" -version = "5.2.0" +version = "5.2.1" description = "An osu! server implementation optimized for maintainability in modern python" authors = ["Akatsuki Team"] license = "MIT" From 6b64b8b352e6019ebf2061a69021a25d44fed55e Mon Sep 17 00:00:00 2001 From: mini <39670899+minisbett@users.noreply.github.com> Date: Mon, 26 Feb 2024 14:36:46 +0100 Subject: [PATCH 44/48] fix: catch pp wrong due to both acc and judgements specified (#627) * Fix both acc and judgements specified * comment * more clear comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix pp snapshot in itest * fix itest snapshot --------- Co-authored-by: cmyui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- app/objects/score.py | 2 -- app/usecases/performance.py | 26 ++++++++++++++++---------- tests/integration/domains/osu_test.py | 2 +- tools/recalc.py | 1 - 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/app/objects/score.py b/app/objects/score.py index 1051e18d0..c9c72601a 100644 --- a/app/objects/score.py +++ b/app/objects/score.py @@ -324,8 +324,6 @@ def calculate_performance(self, beatmap_id: int) -> tuple[float, float]: mode=mode_vn, mods=int(self.mods), combo=self.max_combo, - # prefer to use the score's specific params that add up to the acc - acc=self.acc, ngeki=self.ngeki, n300=self.n300, nkatu=self.nkatu, diff --git a/app/usecases/performance.py b/app/usecases/performance.py index 5db4f0299..2c2ed1015 100644 --- a/app/usecases/performance.py +++ b/app/usecases/performance.py @@ -18,6 +18,7 @@ class ScoreParams: combo: int | None = None # caller may pass either acc OR 300/100/50/geki/katu/miss + # passing both will result in a value error being raised acc: float | None = None n300: int | None = None @@ -60,21 +61,26 @@ def calculate_performances( osu_file_path: str, scores: Iterable[ScoreParams], ) -> list[PerformanceResult]: + """\ + Calculate performance for multiple scores on a single beatmap. + + Typically most useful for mass-recalculation situations. + + TODO: Some level of error handling & returning to caller should be + implemented here to handle cases where e.g. the beatmap file is invalid + or there an issue during calculation. + """ calc_bmap = Beatmap(path=osu_file_path) results: list[PerformanceResult] = [] for score in scores: - # assert either acc OR 300/100/50/geki/katu/miss is present, but not both - # if (score.acc is None) == ( - # score.n300 is None - # and score.n100 is None - # and score.n50 is None - # and score.ngeki is None - # and score.nkatu is None - # and score.nmiss is None - # ): - # raise ValueError("Either acc OR 300/100/50/geki/katu/miss must be present") + if score.acc and ( + score.n300 or score.n100 or score.n50 or score.ngeki or score.nkatu + ): + raise ValueError( + "Must not specify accuracy AND 300/100/50/geki/katu. 
Only one or the other.", + ) # rosupp ignores NC and requires DT if score.mods is not None: diff --git a/tests/integration/domains/osu_test.py b/tests/integration/domains/osu_test.py index 49adc7efa..30c6f086f 100644 --- a/tests/integration/domains/osu_test.py +++ b/tests/integration/domains/osu_test.py @@ -239,5 +239,5 @@ async def test_score_submission( assert response.status_code == status.HTTP_200_OK assert ( response.read() - == b"beatmapId:315|beatmapSetId:141|beatmapPlaycount:1|beatmapPasscount:1|approvedDate:2014-05-18 15:41:48|\n|chartId:beatmap|chartUrl:https://osu.cmyui.xyz/s/141|chartName:Beatmap Ranking|rankBefore:|rankAfter:1|rankedScoreBefore:|rankedScoreAfter:26810|totalScoreBefore:|totalScoreAfter:26810|maxComboBefore:|maxComboAfter:52|accuracyBefore:|accuracyAfter:81.94|ppBefore:|ppAfter:10.041|onlineScoreId:1|\n|chartId:overall|chartUrl:https://cmyui.xyz/u/3|chartName:Overall Ranking|rankBefore:|rankAfter:1|rankedScoreBefore:|rankedScoreAfter:26810|totalScoreBefore:|totalScoreAfter:26810|maxComboBefore:|maxComboAfter:52|accuracyBefore:|accuracyAfter:81.94|ppBefore:|ppAfter:10|achievements-new:osu-skill-pass-4+Insanity Approaches+You're not twitching, you're just ready./all-intro-hidden+Blindsight+I can see just perfectly" + == b"beatmapId:315|beatmapSetId:141|beatmapPlaycount:1|beatmapPasscount:1|approvedDate:2014-05-18 15:41:48|\n|chartId:beatmap|chartUrl:https://osu.cmyui.xyz/s/141|chartName:Beatmap Ranking|rankBefore:|rankAfter:1|rankedScoreBefore:|rankedScoreAfter:26810|totalScoreBefore:|totalScoreAfter:26810|maxComboBefore:|maxComboAfter:52|accuracyBefore:|accuracyAfter:81.94|ppBefore:|ppAfter:10.313|onlineScoreId:1|\n|chartId:overall|chartUrl:https://cmyui.xyz/u/3|chartName:Overall Ranking|rankBefore:|rankAfter:1|rankedScoreBefore:|rankedScoreAfter:26810|totalScoreBefore:|totalScoreAfter:26810|maxComboBefore:|maxComboAfter:52|accuracyBefore:|accuracyAfter:81.94|ppBefore:|ppAfter:11|achievements-new:osu-skill-pass-4+Insanity Approaches+You're not twitching, you're just ready./all-intro-hidden+Blindsight+I can see just perfectly" ) diff --git a/tools/recalc.py b/tools/recalc.py index 172a5053e..d7c019d33 100644 --- a/tools/recalc.py +++ b/tools/recalc.py @@ -66,7 +66,6 @@ async def recalculate_score( calculator = Calculator( mode=GameMode(score["mode"]).as_vanilla, mods=score["mods"], - acc=score["acc"], combo=score["max_combo"], n_geki=score["ngeki"], # Mania 320s n300=score["n300"], From fa94c6d977470f98eb7c67e1b9ff9a0aff17ab4f Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Tue, 27 Feb 2024 21:48:26 -0500 Subject: [PATCH 45/48] Time all db calls (#643) --- app/adapters/database.py | 73 ++++++++++++++++++++++++++++++++++++---- app/timer.py | 27 +++++++++++++++ 2 files changed, 93 insertions(+), 7 deletions(-) create mode 100644 app/timer.py diff --git a/app/adapters/database.py b/app/adapters/database.py index c5e40b437..404aa1c60 100644 --- a/app/adapters/database.py +++ b/app/adapters/database.py @@ -11,6 +11,7 @@ from app import settings from app.logging import log +from app.timer import Timer class MySQLDialect(MySQLDialect_mysqldb): @@ -49,11 +50,18 @@ async def fetch_one( if isinstance(query, ClauseElement): query, params = self._compile(query) - row = await self._database.fetch_one(query, params) + with Timer() as timer: + row = await self._database.fetch_one(query, params) + if settings.DEBUG: + time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params}", - extra={"query": query, "params": params}, + f"Executed SQL query: {query} {params} 
in {time_elapsed:.2f} seconds.", + extra={ + "query": query, + "params": params, + "time_elapsed": time_elapsed, + }, ) return dict(row._mapping) if row is not None else None @@ -66,7 +74,20 @@ async def fetch_all( if isinstance(query, ClauseElement): query, params = self._compile(query) - rows = await self._database.fetch_all(query, params) + with Timer() as timer: + rows = await self._database.fetch_all(query, params) + + if settings.DEBUG: + time_elapsed = timer.elapsed() + log( + f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + extra={ + "query": query, + "params": params, + "time_elapsed": time_elapsed, + }, + ) + return [dict(row._mapping) for row in rows] async def fetch_val( @@ -78,14 +99,40 @@ async def fetch_val( if isinstance(query, ClauseElement): query, params = self._compile(query) - val = await self._database.fetch_val(query, params, column) + with Timer() as timer: + val = await self._database.fetch_val(query, params, column) + + if settings.DEBUG: + time_elapsed = timer.elapsed() + log( + f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + extra={ + "query": query, + "params": params, + "time_elapsed": time_elapsed, + }, + ) + return val async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int: if isinstance(query, ClauseElement): query, params = self._compile(query) - rec_id = await self._database.execute(query, params) + with Timer() as timer: + rec_id = await self._database.execute(query, params) + + if settings.DEBUG: + time_elapsed = timer.elapsed() + log( + f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + extra={ + "query": query, + "params": params, + "time_elapsed": time_elapsed, + }, + ) + return cast(int, rec_id) # NOTE: this accepts str since current execute_many uses are not using alchemy. 
@@ -94,7 +141,19 @@ async def execute_many(self, query: str, params: list[MySQLParams]) -> None: if isinstance(query, ClauseElement): query, _ = self._compile(query) - await self._database.execute_many(query, params) + with Timer() as timer: + await self._database.execute_many(query, params) + + if settings.DEBUG: + time_elapsed = timer.elapsed() + log( + f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + extra={ + "query": query, + "params": params, + "time_elapsed": time_elapsed, + }, + ) def transaction( self, diff --git a/app/timer.py b/app/timer.py new file mode 100644 index 000000000..358926833 --- /dev/null +++ b/app/timer.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import time +from types import TracebackType + + +class Timer: + def __init__(self) -> None: + self.start_time: float | None = None + self.end_time: float | None = None + + def __enter__(self) -> Timer: + self.start_time = time.time() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.end_time = time.time() + + def elapsed(self) -> float: + if self.start_time is None or self.end_time is None: + raise ValueError("Timer has not been started or stopped.") + return self.end_time - self.start_time From 8965739c52ca9049f5a732a2a94acb8299242a4c Mon Sep 17 00:00:00 2001 From: cmyui Date: Wed, 28 Feb 2024 12:02:41 -0500 Subject: [PATCH 46/48] Prefer millisec for timing db latency --- app/adapters/database.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/app/adapters/database.py b/app/adapters/database.py index 404aa1c60..abec747d4 100644 --- a/app/adapters/database.py +++ b/app/adapters/database.py @@ -56,7 +56,7 @@ async def fetch_one( if settings.DEBUG: time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.", extra={ "query": query, "params": params, @@ -80,7 +80,7 @@ async def fetch_all( if settings.DEBUG: time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.", extra={ "query": query, "params": params, @@ -105,7 +105,7 @@ async def fetch_val( if settings.DEBUG: time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.", extra={ "query": query, "params": params, @@ -125,7 +125,7 @@ async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int: if settings.DEBUG: time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.", extra={ "query": query, "params": params, @@ -147,7 +147,7 @@ async def execute_many(self, query: str, params: list[MySQLParams]) -> None: if settings.DEBUG: time_elapsed = timer.elapsed() log( - f"Executed SQL query: {query} {params} in {time_elapsed:.2f} seconds.", + f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.", extra={ "query": query, "params": params, From 23b11246dc6076cc74f9adaa68445990d9f6543f Mon Sep 17 00:00:00 2001 From: Josh Smith Date: Wed, 28 Feb 2024 20:25:26 -0500 Subject: [PATCH 47/48] [v5.2.2] Significantly optimize performance for fetching leaderboards (#642) * 
Significantly optimize performance for fetching leaderboards * v5.2.2 * bump ver --- migrations/base.sql | 3 ++- migrations/migrations.sql | 4 ++++ pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/migrations/base.sql b/migrations/base.sql index b3fdcdc28..79bba0a97 100644 --- a/migrations/base.sql +++ b/migrations/base.sql @@ -254,7 +254,8 @@ create index scores_userid_index on scores (userid); create index scores_online_checksum_index on scores (online_checksum); - +create index scores_fetch_leaderboard_generic_index + on scores (map_md5, status, mode); create table startups ( diff --git a/migrations/migrations.sql b/migrations/migrations.sql index 2db215246..95794b63e 100644 --- a/migrations/migrations.sql +++ b/migrations/migrations.sql @@ -471,3 +471,7 @@ create index users_clan_priv_index on users (clan_priv); create index users_country_index on users (country); + +# v5.2.2 +create index scores_fetch_leaderboard_generic_index + on scores (map_md5, status, mode); diff --git a/pyproject.toml b/pyproject.toml index a3869e5bb..b4b629a24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ profile = "black" [tool.poetry] name = "bancho-py" -version = "5.2.1" +version = "5.2.2" description = "An osu! server implementation optimized for maintainability in modern python" authors = ["Akatsuki Team"] license = "MIT" From ae42bf66110484b3cb98b0f99278f297eda160cf Mon Sep 17 00:00:00 2001 From: cmyui Date: Wed, 28 Feb 2024 20:29:06 -0500 Subject: [PATCH 48/48] shebang for makefile --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index e3d1922c0..d82346b6b 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +#!/usr/bin/env make + build: if [ -d ".dbdata" ]; then sudo chmod -R 755 .dbdata; fi docker build -t bancho:latest .
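The Timer context manager introduced in "Time all db calls" above (app/timer.py) is generic enough to time any awaited call, not just the database adapter. A minimal usage sketch, assuming only the class as it appears in the patch — the asyncio.sleep stand-in and the surrounding script are illustrative, not code from the repository:

    import asyncio

    from app.timer import Timer  # added in the "Time all db calls" patch above


    async def main() -> None:
        # stand-in for an awaited database round-trip
        with Timer() as timer:
            await asyncio.sleep(0.25)

        # elapsed() is only valid after __exit__ has run; calling it inside
        # the block raises ValueError per the implementation above
        print(f"call took {timer.elapsed() * 1000:.2f} msec")


    asyncio.run(main())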
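For context on the v5.2.2 composite index on scores (map_md5, status, mode), a hedged sketch of the leaderboard query shape such an index serves — the helper name, the column list, the LIMIT, and the status value are assumptions for illustration, not code taken from the repository:

    # hypothetical helper; db_conn is an instance of the database wrapper
    # patched above, and status = 2 is assumed to mark a player's best score
    async def fetch_leaderboard(db_conn, map_md5: str, mode: int) -> list[dict]:
        return await db_conn.fetch_all(
            "SELECT s.id, s.userid, s.pp, s.score, s.max_combo, s.acc "
            "FROM scores s "
            "WHERE s.map_md5 = :map_md5 AND s.status = 2 AND s.mode = :mode "
            "ORDER BY s.pp DESC LIMIT 50",
            {"map_md5": map_md5, "mode": mode},
        )

Because all three WHERE predicates are equality comparisons, MySQL can satisfy them with a single range scan over (map_md5, status, mode) instead of filtering every score row for the map, which is what makes the new index effective for leaderboard fetches.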