diff --git a/Pipfile b/Pipfile index ae8c5f44..51267049 100644 --- a/Pipfile +++ b/Pipfile @@ -35,6 +35,7 @@ mypy = "*" autoflake = "*" asgi-lifespan = "*" pytest-asyncio = "*" +types-psutil = "*" [requires] python_version = "3.11" diff --git a/Pipfile.lock b/Pipfile.lock index a8366b92..07a43e7d 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "7a06e42ad2785e9e5c1ffade8bf1f418d962033f9b059ac4e23e97fc91423853" + "sha256": "adae0fd08fc6aa77038a451fae53b99c0a6a4700ee72755e420b6bc1953cfa68" }, "pipfile-spec": 6, "requires": { @@ -1166,6 +1166,14 @@ "markers": "python_version >= '3.7'", "version": "==1.3.0" }, + "types-psutil": { + "hashes": [ + "sha256:4e9b219efb625d3d04f6bf106934f87cab49aa41a94b0a3b3089403f47a79228", + "sha256:fec713104d5d143afea7b976cfa691ca1840f5d19e8714a5d02a96ebd061363e" + ], + "index": "pypi", + "version": "==5.9.5.16" + }, "typing-extensions": { "hashes": [ "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0", diff --git a/app/api/domains/cho.py b/app/api/domains/cho.py index 42146e5d..ce5a4c0a 100644 --- a/app/api/domains/cho.py +++ b/app/api/domains/cho.py @@ -83,7 +83,7 @@ @router.get("/") -async def bancho_http_handler(): +async def bancho_http_handler() -> Response: """Handle a request from a web browser.""" new_line = "\n" matches = [m for m in app.state.sessions.matches if m is not None] @@ -109,7 +109,7 @@ async def bancho_http_handler(): @router.get("/online") -async def bancho_view_online_users(): +async def bancho_view_online_users() -> Response: """see who's online""" new_line = "\n" @@ -132,7 +132,7 @@ async def bancho_view_online_users(): @router.get("/matches") -async def bancho_view_matches(): +async def bancho_view_matches() -> Response: """ongoing matches""" new_line = "\n" @@ -173,7 +173,7 @@ async def bancho_handler( request: Request, osu_token: str | None = Header(None), user_agent: Literal["osu!"] = Header(...), -): +) -> Response: ip = app.state.services.ip_resolver.get_ip(request.headers) if osu_token is None: diff --git a/app/api/domains/osu.py b/app/api/domains/osu.py index ec29c55b..3ba90aab 100644 --- a/app/api/domains/osu.py +++ b/app/api/domains/osu.py @@ -49,6 +49,7 @@ from app.constants.clientflags import LastFMFlags from app.constants.gamemodes import GameMode from app.constants.mods import Mods +from app.constants.privileges import Privileges from app.logging import Ansi from app.logging import log from app.logging import printc @@ -57,7 +58,6 @@ from app.objects.beatmap import ensure_local_osu_file from app.objects.beatmap import RankedStatus from app.objects.player import Player -from app.objects.player import Privileges from app.objects.score import Grade from app.objects.score import Score from app.objects.score import SubmissionStatus @@ -174,11 +174,11 @@ async def osuError( exe_hash: str = Form(..., alias="exehash"), config: str = Form(...), screenshot_file: UploadFile | None = File(None, alias="ss"), -): +) -> Response: """Handle an error submitted from the osu! client.""" if not app.settings.DEBUG: # only handle osu-error in debug mode - return + return Response(b"") if username and pw_md5: player = await app.state.sessions.players.from_login( @@ -201,14 +201,16 @@ async def osuError( # TODO: save error in db? + return Response(b"") + @router.post("/web/osu-screenshot.php") async def osuScreenshot( player: Player = Depends(authenticate_player_session(Form, "u", "p")), endpoint_version: int = Form(..., alias="v"), screenshot_file: UploadFile = File(..., alias="ss"), # TODO: why can't i use bytes? -): - with memoryview(await screenshot_file.read()) as screenshot_view: # type: ignore +) -> Response: + with memoryview(await screenshot_file.read()) as screenshot_view: # png sizes: 1080p: ~300-800kB | 4k: ~1-2mB if len(screenshot_view) > (4 * 1024 * 1024): return Response( @@ -247,8 +249,8 @@ async def osuScreenshot( @router.get("/web/osu-getfriends.php") async def osuGetFriends( player: Player = Depends(authenticate_player_session(Query, "u", "h")), -): - return "\n".join(map(str, player.friends)).encode() +) -> Response: + return Response("\n".join(map(str, player.friends)).encode()) def bancho_to_osuapi_status(bancho_status: int) -> int: @@ -265,7 +267,7 @@ def bancho_to_osuapi_status(bancho_status: int) -> int: async def osuGetBeatmapInfo( form_data: models.OsuBeatmapRequestForm, player: Player = Depends(authenticate_player_session(Query, "u", "h")), -): +) -> Response: num_requests = len(form_data.Filenames) + len(form_data.Ids) log(f"{player} requested info for {num_requests} maps.", Ansi.LCYAN) @@ -309,32 +311,32 @@ async def osuGetBeatmapInfo( f"{player} requested map(s) info by id ({form_data.Ids})", ) - return "\n".join(ret).encode() + return Response("\n".join(ret).encode()) @router.get("/web/osu-getfavourites.php") async def osuGetFavourites( player: Player = Depends(authenticate_player_session(Query, "u", "h")), -): +) -> Response: rows = await app.state.services.database.fetch_all( "SELECT setid FROM favourites WHERE userid = :user_id", {"user_id": player.id}, ) - return "\n".join([str(row["setid"]) for row in rows]).encode() + return Response("\n".join([str(row["setid"]) for row in rows]).encode()) @router.get("/web/osu-addfavourite.php") async def osuAddFavourite( player: Player = Depends(authenticate_player_session(Query, "u", "h")), map_set_id: int = Query(..., alias="a"), -): +) -> Response: # check if they already have this favourited. if await app.state.services.database.fetch_one( "SELECT 1 FROM favourites WHERE userid = :user_id AND setid = :set_id", {"user_id": player.id, "set_id": map_set_id}, ): - return b"You've already favourited this beatmap!" + return Response(b"You've already favourited this beatmap!") # add favourite await app.state.services.database.execute( @@ -342,7 +344,7 @@ async def osuAddFavourite( {"user_id": player.id, "set_id": map_set_id}, ) - return b"Added favourite!" + return Response(b"Added favourite!") @router.get("/web/lastfm.php") @@ -357,11 +359,11 @@ async def lastFM( alias="b", ), player: Player = Depends(authenticate_player_session(Query, "us", "ha")), -): +) -> Response: if beatmap_id_or_hidden_flag[0] != "a": # not anticheat related, tell the # client not to send any more for now. - return b"-3" + return Response(b"-3") flags = LastFMFlags(int(beatmap_id_or_hidden_flag[1:])) @@ -378,7 +380,7 @@ async def lastFM( if player.is_online: player.logout() - return b"-3" + return Response(b"-3") if flags & LastFMFlags.REGISTRY_EDITS: # Player has registry edits left from @@ -397,7 +399,7 @@ async def lastFM( if player.is_online: player.logout() - return b"-3" + return Response(b"-3") # TODO: make a tool to remove the flags & send this as a dm. # also add to db so they never are restricted on first one. @@ -416,7 +418,7 @@ async def lastFM( player.logout() - return b"-3" + return Response(b"-3") """ These checks only worked for ~5 hours from release. rumoi's quick! if flags & ( @@ -432,6 +434,8 @@ async def lastFM( pass """ + return Response(b"") + DIRECT_SET_INFO_FMTSTR = ( "{SetID}.osz|{Artist}|{Title}|{Creator}|" @@ -453,7 +457,7 @@ async def osuSearchHandler( query: str = Query(..., alias="q"), mode: int = Query(..., alias="m", ge=-1, le=3), # -1 for all page_num: int = Query(..., alias="p"), -): +) -> Response: params: dict[str, Any] = {"amount": 100, "offset": page_num * 100} # eventually we could try supporting these, @@ -473,7 +477,7 @@ async def osuSearchHandler( params=params, ) if response.status_code != status.HTTP_200_OK: - return b"-1\nFailed to retrieve data from the beatmap mirror." + return Response(b"-1\nFailed to retrieve data from the beatmap mirror.") result = response.json() @@ -528,7 +532,7 @@ def handle_invalid_characters(s: str) -> str: ), ) - return "\n".join(ret).encode() + return Response("\n".join(ret).encode()) # TODO: video support (needs db change) @@ -537,7 +541,7 @@ async def osuSearchSetHandler( player: Player = Depends(authenticate_player_session(Query, "u", "h")), map_set_id: int | None = Query(None, alias="s"), map_id: int | None = Query(None, alias="b"), -): +) -> Response: # TODO: refactor this to use the new internal bmap(set) api # Since we only need set-specific data, we can basically @@ -549,7 +553,7 @@ async def osuSearchSetHandler( elif map_id is not None: k, v = ("id", map_id) else: - return # invalid args + return Response(b"") # invalid args # Get all set data. rec = await app.state.services.database.fetch_one( @@ -561,18 +565,18 @@ async def osuSearchSetHandler( if rec is None: # TODO: get from osu! - return None + return Response(b"") bmapset = dict(rec._mapping) - return ( + return Response( ( "{set_id}.osz|{artist}|{title}|{creator}|" "{status}|10.0|{last_update}|{set_id}|" # TODO: rating "0|0|0|0|0" ) .format(**bmapset) - .encode() + .encode(), ) # 0s are threadid, has_vid, has_story, filesize, filesize_novid @@ -654,7 +658,7 @@ async def osuSubmitModularSelector( # TODO: do these need to be Optional? # TODO: validate this is actually what it is fl_cheat_screenshot: bytes | None = File(None, alias="i"), -): +) -> Response: """Handle a score submission from an osu! client with an active session.""" if fl_cheat_screenshot: @@ -666,7 +670,7 @@ async def osuSubmitModularSelector( # starlette/fastapi do not support this, so we've moved it out score_parameters = parse_form_data_score_params(await request.form()) if score_parameters is None: - return + return Response(b"") # extract the score data and replay file from the score data score_data_b64, replay_file = score_parameters @@ -685,7 +689,7 @@ async def osuSubmitModularSelector( bmap = await Beatmap.from_md5(bmap_md5) if not bmap: # Map does not exist, most likely unsubmitted. - return b"error: beatmap" + return Response(b"error: beatmap") # if the client has supporter, a space is appended # but usernames may also end with a space, which must be preserved @@ -697,7 +701,7 @@ async def osuSubmitModularSelector( if not player: # Player is not online, return nothing so that their # client will retry submission when they log in. - return + return Response(b"") # parse the score from the remaining data score = Score.from_submission(score_data[2:]) @@ -783,7 +787,7 @@ async def osuSubmitModularSelector( {"checksum": score.client_checksum}, ): log(f"{score.player} submitted a duplicate score.", Ansi.LYELLOW) - return b"error: no" + return Response(b"error: no") # all data read from submission. # now we can calculate things based on our data. @@ -888,6 +892,7 @@ async def osuSubmitModularSelector( ), ) + assert announce_chan is not None announce_chan.send(" ".join(ann), sender=score.player, to_self=True) # this score is our best score. @@ -944,8 +949,8 @@ async def osuSubmitModularSelector( MIN_REPLAY_SIZE = 24 if len(replay_data) >= MIN_REPLAY_SIZE: - replay_file = REPLAYS_PATH / f"{score.id}.osr" - replay_file.write_bytes(replay_data) + replay_disk_file = REPLAYS_PATH / f"{score.id}.osr" + replay_disk_file.write_bytes(replay_data) else: log(f"{score.player} submitted a score without a replay!", Ansi.LRED) @@ -1185,7 +1190,7 @@ async def osuSubmitModularSelector( Ansi.LGREEN, ) - return response + return Response(response) @router.get("/web/osu-getreplay.php") @@ -1193,14 +1198,14 @@ async def getReplay( player: Player = Depends(authenticate_player_session(Query, "u", "h")), mode: int = Query(..., alias="m", ge=0, le=3), score_id: int = Query(..., alias="c", min=0, max=9_223_372_036_854_775_807), -): +) -> Response: score = await Score.from_sql(score_id) if not score: - return + return Response(b"", status_code=404) file = REPLAYS_PATH / f"{score_id}.osr" if not file.exists(): - return + return Response(b"", status_code=404) # increment replay views for this score if score.player is not None and player.id != score.player.id: @@ -1216,18 +1221,18 @@ async def osuRate( ), map_md5: str = Query(..., alias="c", min_length=32, max_length=32), rating: int | None = Query(None, alias="v", ge=1, le=10), -): +) -> Response: if rating is None: # check if we have the map in our cache; # if not, the map probably doesn't exist. if map_md5 not in app.state.cache.beatmap: - return b"no exist" + return Response(b"no exist") cached = app.state.cache.beatmap[map_md5] # only allow rating on maps with a leaderboard. if cached.status < RankedStatus.Ranked: - return b"not ranked" + return Response(b"not ranked") # osu! client is checking whether we can rate the map or not. has_previous_rating = ( @@ -1241,7 +1246,7 @@ async def osuRate( # the client hasn't rated the map, so simply # tell them that they can submit a rating. if not has_previous_rating: - return b"ok" + return Response(b"ok") else: # the client is submitting a rating for the map. await app.state.services.database.execute( @@ -1259,7 +1264,7 @@ async def osuRate( # send back the average rating avg = sum(ratings) / len(ratings) - return f"alreadyvoted\n{avg}".encode() + return Response(f"alreadyvoted\n{avg}".encode()) @unique @@ -1322,7 +1327,7 @@ async def get_leaderboard_scores( if score_rows: # None or [] # fetch player's personal best score - personal_best_score_row = await app.state.services.database.fetch_one( + personal_best_score_rec = await app.state.services.database.fetch_one( f"SELECT id, {scoring_metric} AS _score, " "max_combo, n50, n100, n300, " "nmiss, nkatu, ngeki, perfect, mods, " @@ -1334,7 +1339,9 @@ async def get_leaderboard_scores( {"map_md5": map_md5, "mode": mode, "user_id": player.id}, ) - if personal_best_score_row: + if personal_best_score_rec is not None: + personal_best_score_row = dict(personal_best_score_rec._mapping) + # calculate the rank of the score. p_best_rank = 1 + await app.state.services.database.fetch_val( "SELECT COUNT(*) FROM scores s " @@ -1351,7 +1358,6 @@ async def get_leaderboard_scores( ) # attach rank to personal best row - personal_best_score_row = dict(personal_best_score_row._mapping) personal_best_score_row["rank"] = p_best_rank else: personal_best_score_row = None @@ -1382,7 +1388,7 @@ async def getScores( mods_arg: int = Query(..., alias="mods", ge=0, le=2_147_483_647), map_package_hash: str = Query(..., alias="h"), # TODO: further validation aqn_files_found: bool = Query(..., alias="a"), -): +) -> Response: if aqn_files_found: stacktrace = app.utils.get_appropriate_stacktrace() await app.state.services.log_strange_occurrence(stacktrace) @@ -1390,9 +1396,9 @@ async def getScores( # check if this md5 has already been cached as # unsubmitted/needs update to reduce osu!api spam if map_md5 in app.state.cache.unsubmitted: - return b"-1|false" + return Response(b"-1|false") if map_md5 in app.state.cache.needs_update: - return b"1|false" + return Response(b"1|false") if mods_arg & Mods.RELAX: if mode_arg == 3: # rx!mania doesn't exist @@ -1431,7 +1437,7 @@ async def getScores( if has_set_id and map_set_id not in app.state.cache.beatmapset: # set not cached, it doesn't exist app.state.cache.unsubmitted.add(map_md5) - return b"-1|false" + return Response(b"-1|false") map_filename = unquote_plus(map_filename) # TODO: is unquote needed? @@ -1457,13 +1463,13 @@ async def getScores( if map_exists: # map can be updated. app.state.cache.needs_update.add(map_md5) - return b"1|false" + return Response(b"1|false") else: # map is unsubmitted. # add this map to the unsubmitted cache, so # that we don't have to make this request again. app.state.cache.unsubmitted.add(map_md5) - return b"-1|false" + return Response(b"-1|false") # we've found a beatmap for the request. @@ -1473,7 +1479,7 @@ async def getScores( if bmap.status < RankedStatus.Ranked: # only show leaderboards for ranked, # approved, qualified, or loved maps. - return f"{int(bmap.status)}|false".encode() + return Response(f"{int(bmap.status)}|false".encode()) # fetch scores & personal best # TODO: create a leaderboard cache @@ -1508,7 +1514,7 @@ async def getScores( if not score_rows: response_lines.extend(("", "")) # no scores, no personal best - return "\n".join(response_lines).encode() + return Response("\n".join(response_lines).encode()) if personal_best_score_row is not None: response_lines.append( @@ -1535,7 +1541,7 @@ async def getScores( ], ) - return "\n".join(response_lines).encode() + return Response("\n".join(response_lines).encode()) @router.post("/web/osu-comment.php") @@ -1551,7 +1557,7 @@ async def osuComment( colour: str | None = Form(None, alias="f", min_length=6, max_length=6), start_time: int | None = Form(None, alias="starttime"), comment: str | None = Form(None, min_length=1, max_length=80), -): +) -> Response: if action == "get": # client is requesting all comments comments = [ @@ -1591,7 +1597,7 @@ async def osuComment( ) player.update_latest_activity_soon() - return "\n".join(ret).encode() + return Response("\n".join(ret).encode()) elif action == "post": # client is submitting a new comment @@ -1627,7 +1633,7 @@ async def osuComment( player.update_latest_activity_soon() - return Response(content=b"") # empty resp is fine + return Response(b"") # empty resp is fine @router.get("/web/osu-markasread.php") @@ -1637,7 +1643,7 @@ async def osuMarkAsRead( ) -> Response: target_name = unquote(channel) # TODO: unquote needed? if not target_name: - return Response(content=b"") # no channel specified + return Response(b"") # no channel specified target = await app.state.sessions.players.from_cache_or_sql(name=target_name) if target: @@ -1649,7 +1655,7 @@ async def osuMarkAsRead( {"to": player.id, "from": target.id}, ) - return Response(content=b"") + return Response(b"") @router.get("/web/osu-getseasonal.php") @@ -1668,7 +1674,7 @@ async def banchoConnect( client_hash: str | None = Query(None, alias="ch"), retrying: bool | None = Query(None, alias="retry"), # '0' or '1' ) -> Response: - return Response(content=b"") # TODO + return Response(b"") # TODO _checkupdates_cache = { # default timeout is 1h, set on request. @@ -1685,7 +1691,7 @@ async def checkUpdates( action: Literal["check", "path", "error"], stream: Literal["cuttingedge", "stable40", "beta40", "stable"], ) -> Response: - return Response(content=b"") + return Response(b"") """ Misc handlers """ diff --git a/app/api/init_api.py b/app/api/init_api.py index 56cdd1ee..ecbf68a2 100644 --- a/app/api/init_api.py +++ b/app/api/init_api.py @@ -22,7 +22,7 @@ import app.settings import app.state import app.utils -from app.api import api_router +from app.api import api_router # type: ignore[attr-defined] from app.api import domains from app.api import middlewares from app.logging import Ansi diff --git a/app/api/v2/clans.py b/app/api/v2/clans.py index 58326b67..ae4d6e11 100644 --- a/app/api/v2/clans.py +++ b/app/api/v2/clans.py @@ -6,6 +6,7 @@ from fastapi.param_functions import Query from app.api.v2.common import responses +from app.api.v2.common.responses import Failure from app.api.v2.common.responses import Success from app.api.v2.models.clans import Clan from app.repositories import clans as clans_repo @@ -17,7 +18,7 @@ async def get_clans( page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), -) -> Success[list[Clan]]: +) -> Success[list[Clan]] | Failure: clans = await clans_repo.fetch_many( page=page, page_size=page_size, @@ -36,7 +37,7 @@ async def get_clans( @router.get("/clans/{clan_id}") -async def get_clan(clan_id: int) -> Success[Clan]: +async def get_clan(clan_id: int) -> Success[Clan] | Failure: data = await clans_repo.fetch_one(id=clan_id) if data is None: return responses.failure( diff --git a/app/api/v2/common/responses.py b/app/api/v2/common/responses.py index a25cf4a9..8017b937 100644 --- a/app/api/v2/common/responses.py +++ b/app/api/v2/common/responses.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any +from typing import cast from typing import Generic from typing import Literal from typing import TypeVar @@ -20,21 +21,21 @@ class Success(BaseModel, Generic[T]): def success( - content: Any, + content: T, status_code: int = 200, headers: dict[str, Any] | None = None, meta: dict[str, Any] | None = None, -) -> Any: +) -> Success[T]: if meta is None: meta = {} data = {"status": "success", "data": content, "meta": meta} - return json.ORJSONResponse(data, status_code, headers) + # XXX:HACK to make typing work + return cast(Success[T], json.ORJSONResponse(data, status_code, headers)) -class ErrorResponse(BaseModel, Generic[T]): +class Failure(BaseModel): status: Literal["error"] - error: T - message: str + error: str def failure( @@ -42,6 +43,7 @@ def failure( message: str, status_code: int = 400, headers: dict[str, Any] | None = None, -) -> Any: +) -> Failure: data = {"status": "error", "error": message} - return json.ORJSONResponse(data, status_code, headers) + # XXX:HACK to make typing work + return cast(Failure, json.ORJSONResponse(data, status_code, headers)) diff --git a/app/api/v2/maps.py b/app/api/v2/maps.py index 07cbe42f..98b87a5f 100644 --- a/app/api/v2/maps.py +++ b/app/api/v2/maps.py @@ -6,6 +6,7 @@ from fastapi.param_functions import Query from app.api.v2.common import responses +from app.api.v2.common.responses import Failure from app.api.v2.common.responses import Success from app.api.v2.models.maps import Map from app.repositories import maps as maps_repo @@ -25,7 +26,7 @@ async def get_maps( frozen: bool | None = None, page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), -) -> Success[list[Map]]: +) -> Success[list[Map]] | Failure: maps = await maps_repo.fetch_many( server=server, set_id=set_id, @@ -62,7 +63,7 @@ async def get_maps( @router.get("/maps/{map_id}") -async def get_map(map_id: int) -> Success[Map]: +async def get_map(map_id: int) -> Success[Map] | Failure: data = await maps_repo.fetch_one(id=map_id) if data is None: return responses.failure( diff --git a/app/api/v2/players.py b/app/api/v2/players.py index 4488098a..d8aad97c 100644 --- a/app/api/v2/players.py +++ b/app/api/v2/players.py @@ -7,6 +7,7 @@ import app.state.sessions from app.api.v2.common import responses +from app.api.v2.common.responses import Failure from app.api.v2.common.responses import Success from app.api.v2.models.players import Player from app.api.v2.models.players import PlayerStats @@ -27,7 +28,7 @@ async def get_players( play_style: int | None = None, page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), -) -> Success[list[Player]]: +) -> Success[list[Player]] | Failure: players = await players_repo.fetch_many( priv=priv, country=country, @@ -60,7 +61,7 @@ async def get_players( @router.get("/players/{player_id}") -async def get_player(player_id: int) -> Success[Player]: +async def get_player(player_id: int) -> Success[Player] | Failure: data = await players_repo.fetch_one(id=player_id) if data is None: return responses.failure( @@ -73,7 +74,7 @@ async def get_player(player_id: int) -> Success[Player]: @router.get("/players/{player_id}/status") -async def get_player_status(player_id: int) -> Success[PlayerStatus]: +async def get_player_status(player_id: int) -> Success[PlayerStatus] | Failure: player = app.state.sessions.players.get(id=player_id) if not player: @@ -94,7 +95,10 @@ async def get_player_status(player_id: int) -> Success[PlayerStatus]: @router.get("/players/{player_id}/stats/{mode}") -async def get_player_mode_stats(player_id: int, mode: int) -> Success[PlayerStats]: +async def get_player_mode_stats( + player_id: int, + mode: int, +) -> Success[PlayerStats] | Failure: data = await stats_repo.fetch_one(player_id, mode) if data is None: return responses.failure( @@ -111,7 +115,7 @@ async def get_player_stats( player_id: int, page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), -) -> Success[list[PlayerStats]]: +) -> Success[list[PlayerStats]] | Failure: data = await stats_repo.fetch_many( player_id=player_id, page=page, diff --git a/app/api/v2/scores.py b/app/api/v2/scores.py index 001ce275..e81f8817 100644 --- a/app/api/v2/scores.py +++ b/app/api/v2/scores.py @@ -6,6 +6,7 @@ from fastapi.param_functions import Query from app.api.v2.common import responses +from app.api.v2.common.responses import Failure from app.api.v2.common.responses import Success from app.api.v2.models.scores import Score from app.repositories import scores as scores_repo @@ -22,7 +23,7 @@ async def get_all_scores( user_id: int | None = None, page: int = Query(1, ge=1), page_size: int = Query(50, ge=1, le=100), -) -> Success[list[Score]]: +) -> Success[list[Score]] | Failure: scores = await scores_repo.fetch_many( map_md5=map_md5, mods=mods, @@ -53,7 +54,7 @@ async def get_all_scores( @router.get("/scores/{score_id}") -async def get_score(score_id: int) -> Success[Score]: +async def get_score(score_id: int) -> Success[Score] | Failure: data = await scores_repo.fetch_one(id=score_id) if data is None: return responses.failure( diff --git a/app/commands.py b/app/commands.py index 6c72c14e..9810e654 100644 --- a/app/commands.py +++ b/app/commands.py @@ -25,7 +25,6 @@ from typing import Optional from typing import TYPE_CHECKING from typing import TypedDict -from typing import TypeVar from urllib.parse import urlparse import psutil @@ -66,9 +65,6 @@ from app.objects.channel import Channel -R = TypeVar("R") - - BEATMAPS_PATH = Path.cwd() / ".data/osu" @@ -845,13 +841,15 @@ async def user(ctx: Context) -> str | None: player = ctx.player else: # username given, fetch the player - player = await app.state.sessions.players.from_cache_or_sql( + maybe_player = await app.state.sessions.players.from_cache_or_sql( name=" ".join(ctx.args), ) - if not player: + if maybe_player is None: return "Player not found." + player = maybe_player + priv_list = [ priv.name for priv in Privileges @@ -863,7 +861,7 @@ async def user(ctx: Context) -> str | None: last_np = None if player.is_online and player.client_details is not None: - osu_version = player.client_details.osu_version.date + osu_version = player.client_details.osu_version.date.isoformat() else: osu_version = "Unknown" @@ -1025,6 +1023,7 @@ async def shutdown(ctx: Context) -> str | None | NoReturn: return f"Enqueued {ctx.trigger}." else: # shutdown immediately os.kill(os.getpid(), _signal) + return "Process killed" """ Developer commands @@ -1202,11 +1201,12 @@ async def reload(ctx: Context) -> str | None: except ModuleNotFoundError: return "Module not found." + child = None try: for child in children: mod = getattr(mod, child) except AttributeError: - return f"Failed at {child}." # type: ignore + return f"Failed at {child}." try: mod = importlib.reload(mod) @@ -1264,7 +1264,7 @@ async def server(ctx: Context) -> str | None: requirements = [] for dist in importlib.metadata.distributions(): - requirements.append(f"{dist.name} v{dist.version}") # type: ignore + requirements.append(f"{dist.name} v{dist.version}") requirements.sort(key=lambda x: x.casefold()) requirements_info = "\n".join( @@ -1341,7 +1341,7 @@ async def py(ctx: Context) -> str | None: if not isinstance(ret, str): ret = pprint.pformat(ret, compact=True) - return ret + return str(ret) """ Multiplayer commands @@ -1351,10 +1351,10 @@ async def py(ctx: Context) -> str | None: def ensure_match( - f: Callable[[Context, Match], Awaitable[R | None]], -) -> Callable[[Context], Awaitable[R | None]]: + f: Callable[[Context, Match], Awaitable[str | None]], +) -> Callable[[Context], Awaitable[str | None]]: @wraps(f) - async def wrapper(ctx: Context) -> R | None: + async def wrapper(ctx: Context) -> str | None: match = ctx.player.match # multi set is a bit of a special case, @@ -1367,11 +1367,11 @@ async def wrapper(ctx: Context) -> R | None: # message not in match channel return None - if f is not mp_help and ( - ctx.player not in match.refs - and not ctx.player.priv & Privileges.TOURNEY_MANAGER + if not ( + ctx.player in match.refs + or ctx.player.priv & Privileges.TOURNEY_MANAGER + or f is mp_help.__wrapped__ # type: ignore[attr-defined] ): - # doesn't have privs to use !mp commands (allow help). return None return await f(ctx, match) @@ -2096,13 +2096,12 @@ async def pool_create(ctx: Context) -> str | None: ) # add to cache (get from sql for id & time) - row = await app.state.services.database.fetch_one( + rec = await app.state.services.database.fetch_one( "SELECT * FROM tourney_pools WHERE name = :name", {"name": name}, ) - assert row is not None - - row = dict(row) # make mutable copy + assert rec is not None + row = dict(rec._mapping) pool_creator = await app.state.sessions.players.from_cache_or_sql( id=row["created_by"], @@ -2328,7 +2327,7 @@ async def clan_create(ctx: Context) -> str | None: created_at = datetime.now() # add clan to sql - clan = await clans_repo.create( + persisted_clan = await clans_repo.create( name=name, tag=tag, owner=ctx.player.id, @@ -2336,7 +2335,7 @@ async def clan_create(ctx: Context) -> str | None: # add clan to cache clan = Clan( - id=clan["id"], + id=persisted_clan["id"], name=name, tag=tag, created_at=created_at, @@ -2428,7 +2427,7 @@ async def clan_info(ctx: Context) -> str | None: @clan_commands.add(Privileges.UNRESTRICTED) -async def clan_leave(ctx: Context): +async def clan_leave(ctx: Context) -> str | None: """Leaves the clan you're in.""" if not ctx.player.clan: return "You're not in a clan." diff --git a/app/repositories/ingame_logins.py b/app/repositories/ingame_logins.py index 541f2451..5c6ce1b0 100644 --- a/app/repositories/ingame_logins.py +++ b/app/repositories/ingame_logins.py @@ -1,6 +1,7 @@ from __future__ import annotations import textwrap +from datetime import date from datetime import datetime from typing import Any from typing import cast @@ -30,7 +31,7 @@ class IngameLogin(TypedDict): id: int userid: str ip: str - osu_ver: str + osu_ver: date osu_stream: str datetime: datetime @@ -38,14 +39,14 @@ class IngameLogin(TypedDict): class InGameLoginUpdateFields(TypedDict, total=False): userid: str ip: str - osu_ver: str + osu_ver: date osu_stream: str async def create( user_id: int, ip: str, - osu_ver: str, + osu_ver: date, osu_stream: str, ) -> IngameLogin: """Create a new login entry in the database.""" @@ -113,7 +114,7 @@ async def fetch_count( async def fetch_many( user_id: int | None = None, ip: str | None = None, - osu_ver: str | None = None, + osu_ver: date | None = None, osu_stream: str | None = None, page: int | None = None, page_size: int | None = None, diff --git a/mypy.ini b/mypy.ini index 3631598b..5d0e5abd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,19 +1,28 @@ [mypy] strict = True disallow_untyped_calls = False -exclude = ["tests"] -[mypy-aiomysql] +[mypy-tests.*] +disable_error_code = var-annotated, has-type +allow_untyped_defs = True + +[mypy-aiomysql.*] +ignore_missing_imports = True + +[mypy-pymysql.*] +ignore_missing_imports = True + +[mypy-mitmproxy.*] ignore_missing_imports = True -[mypy-pymysql] +[mypy-py3rijndael.*] ignore_missing_imports = True -[mypy-mitmproxy] +[mypy-requests.*] ignore_missing_imports = True -[mypy-py3rijndael] +[mypy-timeago.*] ignore_missing_imports = True -[mypy-requests] +[mypy-pytimeparse.*] ignore_missing_imports = True diff --git a/requirements-dev.txt b/requirements-dev.txt index fe5a3e45..3a273c62 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,7 @@ pyyaml==6.0.1; python_version >= '3.6' reorder-python-imports==3.11.0; python_version >= '3.8' setuptools==68.2.2; python_version >= '3.8' sniffio==1.3.0; python_version >= '3.7' +types-psutil==5.9.5.16 typing-extensions==4.8.0; python_version >= '3.8' virtualenv==20.24.5; python_version >= '3.7' aiomysql==0.2.0