From 686123eae37563c2b48057e8e4c572013152e1dd Mon Sep 17 00:00:00 2001 From: arily Date: Thu, 7 Dec 2023 23:56:25 +0900 Subject: [PATCH] Squashed commit of the following: commit c1a0c54fe472f35b461a7dd5e3e490a84cf88fe7 Author: arily Date: Thu Dec 7 23:34:51 2023 +0900 * commit 3ee1d4c6857002a4ba0f03315b05a14cbd6f5fd3 Author: arily Date: Thu Dec 7 23:31:01 2023 +0900 WHERE commit c4aa1e0ddbcb6d921b7a36d5b234a13c4687c145 Author: arily Date: Thu Dec 7 23:22:55 2023 +0900 update stats commit be434e103911f3d3fe137b10619a18f572bf889d Author: arily Date: Thu Dec 7 23:09:23 2023 +0900 AND commit feb6001ae61e4b7347700f7256d2a165962c6efe Author: arily Date: Thu Dec 7 22:47:05 2023 +0900 use eq var commit 710dc3698bbd8f0b46a1bad33f7c4bacbe334726 Author: arily Date: Thu Dec 7 22:46:39 2023 +0900 apply use query builder commit 4e11a44175cb4c7b36b87cf9a5d48930e4a1f995 Author: arily Date: Thu Dec 7 21:48:10 2023 +0900 lib --- app/api/v1/api.py | 24 ++- app/query_builder.py | 255 ++++++++++++++++++++++++++++++ app/repositories/achievements.py | 28 ++-- app/repositories/channels.py | 80 ++++------ app/repositories/clans.py | 29 ++-- app/repositories/ingame_logins.py | 53 +++---- app/repositories/maps.py | 102 +++++------- app/repositories/players.py | 113 ++++++------- app/repositories/scores.py | 74 ++++----- app/repositories/stats.py | 72 +++++---- 10 files changed, 508 insertions(+), 322 deletions(-) create mode 100644 app/query_builder.py diff --git a/app/api/v1/api.py b/app/api/v1/api.py index bfb7deb84..23225a1e2 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -21,6 +21,7 @@ from app.constants import regexes from app.constants.gamemodes import GameMode from app.constants.mods import Mods +from app.constants.privileges import Privileges from app.objects.beatmap import Beatmap from app.objects.beatmap import ensure_local_osu_file from app.objects.clan import Clan @@ -30,6 +31,8 @@ from app.repositories import stats as stats_repo from app.usecases.performance import ScoreParams +from app.query_builder import build as bq, sql + AVATARS_PATH = SystemPath.cwd() / ".data/avatars" BEATMAPS_PATH = SystemPath.cwd() / ".data/osu" REPLAYS_PATH = SystemPath.cwd() / ".data/osr" @@ -194,20 +197,25 @@ async def api_calculate_pp( ) +VisibleUser = Privileges.UNRESTRICTED | Privileges.VERIFIED + + @router.get("/search_players") async def api_search_players( search: str | None = Query(None, alias="q", min=2, max=32), ) -> Response: """Search for users on the server by name.""" - rows = await app.state.services.database.fetch_all( - "SELECT id, name " - "FROM users " - "WHERE name LIKE COALESCE(:name, name) " - "AND priv & 3 = 3 " - "ORDER BY id ASC", - {"name": f"%{search}%" if search is not None else None}, + + query, params = bq( + sql("SELECT id, name FROM users WHERE 1 = 1"), + (search, lambda: (f"%{search}%", sql("AND name LIKE :name"))), + sql(f"AND priv & {VisibleUser} = {VisibleUser}"), + sql("ORDER BY id ASC"), + sql("LIMIT 100"), ) + rows = await app.state.services.database.fetch_all(query, params) + return ORJSONResponse( { "status": "success", @@ -679,7 +687,7 @@ async def api_get_map_scores( "WHERE s.map_md5 = :map_md5 " "AND s.mode = :mode " "AND s.status = 2 " - "AND u.priv & 1", + "AND u.priv & 1" ] params: dict[str, object] = { "map_md5": bmap.md5, diff --git a/app/query_builder.py b/app/query_builder.py new file mode 100644 index 000000000..a0b6a070b --- /dev/null +++ b/app/query_builder.py @@ -0,0 +1,255 @@ +import traceback +import random +import string +from typing import Union, Callable, Tuple, Dict, Optional + + +class SQLLiteral: + def __init__(self, value: str): + self.value = value + + +# Types +DatabaseAllowedNotNull = Union[str, int, bool, float] +Value = Union[DatabaseAllowedNotNull, None] +SQLValueWithTemplate = Tuple[Value, str] +SQLValueWithNested = Tuple[Value, "SQLPart"] +BuiltSQL = SubQuery = Tuple[str, Dict[str, Value]] +SQLType = Union[ + SQLValueWithTemplate, SQLValueWithNested, BuiltSQL, SQLLiteral, "OptionalSQL" +] +SQLPart = Union[SQLType, Tuple[SQLType, ...], Callable[[], SQLType]] + + +class Nullable: + def __init__(self, value: Value): + self.value = value + + +class OptionalSQL: + def __init__(self, value: Value, sql: SQLPart): + self.value = value + self.sql = sql + + +def sql(value: str) -> SQLLiteral: + return SQLLiteral(value) + + +def equals_variable(name: str, param_key: str | None = None) -> SQLLiteral: + return SQLLiteral( + f"`{name}` = :{param_key if param_key is not None else generate_random_string(5)}" + ) + + +def generate_random_string(length): + # Define the characters to use in the string + characters = string.ascii_letters + string.digits + # Generate a random string of the specified length + random_string = "".join(random.choice(characters) for i in range(length)) + return random_string + + +def table(value: str) -> SQLLiteral: + return SQLLiteral(f"`{value}`") + + +def nullable(value: DatabaseAllowedNotNull | None) -> Nullable: + return Nullable(value) + + +def optional_param(value: Value, sql: SQLPart) -> OptionalSQL: + return OptionalSQL(value, sql) + + +def WHERE(*parts: SQLPart) -> SQLPart: + return (sql("WHERE"), parts) + + +def UPDATE(table: SQLLiteral, update_set: BuiltSQL, cond: SQLPart) -> BuiltSQL: + _sql, param = update_set + if _sql is None: + return None + + built, bparam = build(cond) + if built is None: + return None + + param.update(bparam) + return ("UPDATE " + table.value + " SET " + _sql + " " + built, bparam) + + +def SET(*parts: SQLPart) -> BuiltSQL: + parameters = {} + query_parts = (_process_query_part(p, parameters) for p in parts) + filtered = [q for q in query_parts if q is not None] + + # Check if filtered is empty or contains only None + if not filtered or all(part is None for part in filtered): + return None, parameters # or some other appropriate return value + return ", ".join(filtered), parameters + + +def AND(*parts: SQLPart) -> BuiltSQL: + update = {} + subq = _process_query_part((sql("AND"), parts), update) + return (subq, update) + + +def OR(*parts: SQLPart) -> BuiltSQL: + update = {} + subq = _process_query_part((sql("OR ("), (parts, sql(")"))), update) + return (subq, update) + + +def build(*parts: SQLPart) -> BuiltSQL: + params = {} + query_parts = (_process_query_part(p, params) for p in parts) + query = " ".join(q for q in query_parts if q is not None) + return query, params + + +def _is_nullable(value) -> bool: + return isinstance(value, Nullable) + + +def _extract_value(value: Value | Nullable) -> Value: + if _is_nullable(value): + return value.value + return value + + +def _process_query_part(part: SQLPart, parameters: Dict[str, Value]) -> Optional[str]: + if callable(part): + part = part() + + match part: + case None: + return None + + case (str(built), dict(params)): + parameters.update(params) + return built + + case SQLLiteral(value=value): + return value + case OptionalSQL(value=value, sql=parts): + update = {} + sql = _process_query_part((value, parts), update) + if sql is None: + return None + parameters.update(update) + return sql + case (str(literal), dict(params)): + parameters.update(params) + return literal + case tuple(parts): + return _process_tuple_part(parts, parameters) + case str(_literal): + raise TypeError( + "Raw strings are not allowed. Use sql() function for literal SQL strings." + ) + case _: + traceback.print_stack() + raise TypeError(f"Unexpected type for query part: {type(part)}, {part}") + + +def _process_tuple_part(part: SQLPart, parameters: Dict[str, Value]) -> Optional[str]: + match part: + case (None, *_): + return None + case (SQLLiteral(value=value),): + return value + case (SQLLiteral(value=value), val): + update = {} + evaluated = _process_query_part(val, update) + if evaluated is None: + return None + parameters.update(update) + return value + " " + evaluated + case ( + Nullable(value=cond) | bool(cond) | str(cond) | int(cond) | float(cond), + val, + ): + evaluated = _process_query_part(val, parameters) + revealed_cond = _extract_value(cond) + if evaluated is None: + return None + if revealed_cond is None and _is_nullable(cond) is False: + return None + if ":" in evaluated and revealed_cond is not None: + parameter_name = evaluated.split(":")[-1].strip().split(" ")[0].strip() + parameters[parameter_name] = revealed_cond + return evaluated + case tuple(_), *_: + parts: list[str] = [] + for elem in part: + return_value = _process_query_part(elem, parameters) + if return_value is None: + return None + parts.append(return_value) + return " ".join(parts) + case _: + raise TypeError(f"Unexpected type for tuple part: {type(part)} {part}") + + +# def test_query( +# mode: int | None = None, +# page: int | None = None, +# player_id: int | None = None, +# page_size: int | None = None, +# ): +# return build( +# sql(f"SELECT 1 FROM stats WHERE 1 = 1"), +# AND(player_id, equals_variable("id")), +# AND(mode, equals_variable("mode")), +# ( +# (page_size, sql("LIMIT :page_size")), +# lambda: ( +# (page - 1) * page_size if page is not None else None, +# sql("OFFSET :offset"), +# ), +# ), +# ) + + +# def test_update( +# id: int, +# _from: Optional[int] = None, +# to: Optional[int] = None, +# action: Optional[str] = None, +# msg: Optional[str] = None, +# time: Optional[str] = None, +# ) -> BuiltSQL: +# """Update a log entry in the database.""" + +# query, params = build( +# UPDATE( +# table("logs"), +# SET( +# optional_param(to, equals_variable("to")), +# optional_param(action, equals_variable("action")), +# optional_param(msg, equals_variable("msg")), +# optional_param(time, equals_variable("time")), +# ), +# ( +# sql("WHERE"), +# equals_variable("id", "id"), +# ), +# ), +# ) +# params.update({"id": id}) if query != "" else None +# return query, params + + +# print("query:empty") +# print(test_query()) +# print() +# print("query:populated") +# print(test_query(player_id=1001, mode=0, page=1, page_size=10)) +# print() +# print("update:empty") +# print(test_update(id=3)) +# print() +# print("update:populated") +# print(test_update(id=3, to=1001, action="test", msg="message", time="datetime")) diff --git a/app/repositories/achievements.py b/app/repositories/achievements.py index fd8119b8e..bcdc36d3e 100644 --- a/app/repositories/achievements.py +++ b/app/repositories/achievements.py @@ -10,6 +10,7 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import build as bq, sql, equals_variable, WHERE if TYPE_CHECKING: from app.objects.score import Score @@ -89,16 +90,12 @@ async def fetch_one( if id is None and name is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM achievements - WHERE id = COALESCE(:id, id) - OR name = COALESCE(:name, name) - """ - params: dict[str, Any] = { - "id": id, - "name": name, - } + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM achievements WHERE 1 = 1"), + AND(id, equals_variable("id", "id")), + AND(name, equals_variable("name", "name")), + ) + rec = await app.state.services.database.fetch_one(query, params) if rec is None: @@ -171,11 +168,12 @@ async def update( if not isinstance(cond, _UnsetSentinel): update_fields["cond"] = cond - query = f"""\ - UPDATE achievements - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ + query, _ = bq( + sql("UPDATE achievements SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("id", "id")), + ) + values = {"id": id} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/channels.py b/app/repositories/channels.py index 04d0962ce..18e7038b5 100644 --- a/app/repositories/channels.py +++ b/app/repositories/channels.py @@ -8,6 +8,7 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import build as bq, sql, equals_variable, AND, WHERE # +------------+--------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | @@ -88,16 +89,13 @@ async def fetch_one( """Fetch a single channel.""" if id is None and name is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE id = COALESCE(:id, id) - AND name = COALESCE(:name, name) - """ - params: dict[str, Any] = { - "id": id, - "name": name, - } + + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM channels WHERE 1 = 1"), + AND(id, equals_variable("id", "id")), + AND(name, equals_variable("name", "name")), + ) + channel = await app.state.services.database.fetch_one(query, params) return cast(Channel, dict(channel._mapping)) if channel is not None else None @@ -111,18 +109,12 @@ async def fetch_count( if read_priv is None and write_priv is None and auto_join is None: raise ValueError("Must provide at least one parameter.") - query = """\ - SELECT COUNT(*) AS count - FROM channels - WHERE read_priv = COALESCE(:read_priv, read_priv) - AND write_priv = COALESCE(:write_priv, write_priv) - AND auto_join = COALESCE(:auto_join, auto_join) - """ - params: dict[str, Any] = { - "read_priv": read_priv, - "write_priv": write_priv, - "auto_join": auto_join, - } + query, params = bq( + sql("SELECT COUNT(*) count FROM channels WHERE 1 = 1"), + AND(read_priv, equals_variable("read_priv", "read_priv")), + AND(write_priv, equals_variable("write_priv", "write_priv")), + AND(auto_join, equals_variable("auto_join", "auto_join")), + ) rec = await app.state.services.database.fetch_one(query, params) assert rec is not None @@ -137,26 +129,19 @@ async def fetch_many( page_size: int | None = None, ) -> list[Channel]: """Fetch multiple channels from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM channels - WHERE read_priv = COALESCE(:read_priv, read_priv) - AND write_priv = COALESCE(:write_priv, write_priv) - AND auto_join = COALESCE(:auto_join, auto_join) - """ - params: dict[str, Any] = { - "read_priv": read_priv, - "write_priv": write_priv, - "auto_join": auto_join, - } - - if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM channels WHERE 1 = 1"), + AND(read_priv, equals_variable("read_priv", "read_priv")), + AND(write_priv, equals_variable("write_priv", "write_priv")), + AND(auto_join, equals_variable("auto_join", "auto_join")), + ( + (page_size, sql("LIMIT :page_size")), + lambda: ( + (page - 1) * page_size if page is not None else None, + sql("OFFSET :offset"), + ), + ), + ) channels = await app.state.services.database.fetch_all(query, params) return cast(list[Channel], [dict(c._mapping) for c in channels]) @@ -180,11 +165,12 @@ async def update( if not isinstance(auto_join, _UnsetSentinel): update_fields["auto_join"] = auto_join - query = f"""\ - UPDATE channels - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE name = :name - """ + query, _ = bq( + sql("UPDATE channels SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("name", "name")), + ) + values = {"name": name} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/clans.py b/app/repositories/clans.py index fc11ef672..924f28ea1 100644 --- a/app/repositories/clans.py +++ b/app/repositories/clans.py @@ -9,6 +9,7 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import build as bq, sql, equals_variable, AND, WHERE # +------------+-------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | @@ -82,15 +83,14 @@ async def fetch_one( if id is None and name is None and tag is None and owner is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM clans - WHERE id = COALESCE(:id, id) - AND name = COALESCE(:name, name) - AND tag = COALESCE(:tag, tag) - AND owner = COALESCE(:owner, owner) - """ - params: dict[str, Any] = {"id": id, "name": name, "tag": tag, "owner": owner} + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM clans WHERE 1 = 1"), + AND(id, equals_variable("id", "id")), + AND(name, equals_variable("name", "name")), + AND(tag, equals_variable("tag", "tag")), + AND(owner, equals_variable("owner", "owner")), + ) + clan = await app.state.services.database.fetch_one(query, params) return cast(Clan, dict(clan._mapping)) if clan is not None else None @@ -145,11 +145,12 @@ async def update( if not isinstance(owner, _UnsetSentinel): update_fields["owner"] = owner - query = f"""\ - UPDATE clans - SET {",".join(f"{k} = :{k}" for k in update_fields)} - WHERE id = :id - """ + query, _ = bq( + sql("UPDATE clans SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("id", "id")), + ) + values = {"id": id} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/ingame_logins.py b/app/repositories/ingame_logins.py index 5c6ce1b03..a36615292 100644 --- a/app/repositories/ingame_logins.py +++ b/app/repositories/ingame_logins.py @@ -6,6 +6,7 @@ from typing import Any from typing import cast from typing import TypedDict +from app.query_builder import build as bq, sql, equals_variable, AND import app.state.services @@ -96,16 +97,12 @@ async def fetch_count( ip: str | None = None, ) -> int: """Fetch the number of logins in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM ingame_logins - WHERE userid = COALESCE(:userid, userid) - AND ip = COALESCE(:ip, ip) - """ - params: dict[str, Any] = { - "userid": user_id, - "ip": ip, - } + query, params = bq( + sql("SELECT COUNT(*) count FROM ingame_logins WHERE 1 = 1"), + AND(user_id, equals_variable("userid", "userid")), + AND(ip, equals_variable("ip", "ip")), + ) + rec = await app.state.services.database.fetch_one(query, params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -120,28 +117,20 @@ async def fetch_many( page_size: int | None = None, ) -> list[IngameLogin]: """Fetch a list of logins from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM ingame_logins - WHERE userid = COALESCE(:userid, userid) - AND ip = COALESCE(:ip, ip) - AND osu_ver = COALESCE(:osu_ver, osu_ver) - AND osu_stream = COALESCE(:osu_stream, osu_stream) - """ - params: dict[str, Any] = { - "userid": user_id, - "ip": ip, - "osu_ver": osu_ver, - "osu_stream": osu_stream, - } - - if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM ingame_logins WHERE 1 = 1"), + AND(user_id, equals_variable("userid", "userid")), + AND(ip, equals_variable("ip", "ip")), + AND(osu_ver, equals_variable("osu_ver", "osu_ver")), + AND(osu_stream, equals_variable("osu_stream", "osu_stream")), + ( + (page_size, sql("LIMIT :page_size")), + lambda: ( + (page - 1) * page_size if page is not None else None, + sql("OFFSET :offset"), + ), + ), + ) ingame_logins = await app.state.services.database.fetch_all(query, params) return cast(list[IngameLogin], ingame_logins) diff --git a/app/repositories/maps.py b/app/repositories/maps.py index 1d55000cb..2262e5171 100644 --- a/app/repositories/maps.py +++ b/app/repositories/maps.py @@ -8,6 +8,7 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import build as bq, sql, equals_variable, AND, WHERE # +--------------+------------------------+------+-----+---------+-------+ # | Field | Type | Null | Key | Default | Extra | @@ -183,21 +184,15 @@ async def fetch_one( if id is None and md5 is None and filename is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE id = COALESCE(:id, id) - AND md5 = COALESCE(:md5, md5) - AND filename = COALESCE(:filename, filename) - """ - params: dict[str, Any] = { - "id": id, - "md5": md5, - "filename": filename, - } - map = await app.state.services.database.fetch_one(query, params) + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM maps WHERE 1 = 1"), + AND(id, equals_variable("id", "id")), + AND(md5, equals_variable("md5", "md5")), + AND(filename, equals_variable("filename", "filename")), + ) - return cast(Map, dict(map._mapping)) if map is not None else None + result = await app.state.services.database.fetch_one(query, params) + return cast(Map, dict(result._mapping)) if result is not None else None async def fetch_count( @@ -211,29 +206,19 @@ async def fetch_count( frozen: bool | None = None, ) -> int: """Fetch the number of maps in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM maps - WHERE server = COALESCE(:server, server) - AND set_id = COALESCE(:set_id, set_id) - AND status = COALESCE(:status, status) - AND artist = COALESCE(:artist, artist) - AND creator = COALESCE(:creator, creator) - AND filename = COALESCE(:filename, filename) - AND mode = COALESCE(:mode, mode) - AND frozen = COALESCE(:frozen, frozen) - """ - params: dict[str, Any] = { - "server": server, - "set_id": set_id, - "status": status, - "artist": artist, - "creator": creator, - "filename": filename, - "mode": mode, - "frozen": frozen, - } + query, params = bq( + sql("SELECT COUNT(*) count FROM maps WHERE 1 = 1"), + AND(server, equals_variable("server", "server")), + AND(set_id, equals_variable("set_id", "set_id")), + AND(status, equals_variable("status", "status")), + AND(artist, equals_variable("artist", "artist")), + AND(creator, equals_variable("creator", "creator")), + AND(filename, equals_variable("filename", "filename")), + AND(mode, equals_variable("mode", "mode")), + AND(frozen, equals_variable("frozen", "frozen")), + ) + rec = await app.state.services.database.fetch_one(query, params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -252,28 +237,18 @@ async def fetch_many( page_size: int | None = None, ) -> list[Map]: """Fetch a list of maps from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM maps - WHERE server = COALESCE(:server, server) - AND set_id = COALESCE(:set_id, set_id) - AND status = COALESCE(:status, status) - AND artist = COALESCE(:artist, artist) - AND creator = COALESCE(:creator, creator) - AND filename = COALESCE(:filename, filename) - AND mode = COALESCE(:mode, mode) - AND frozen = COALESCE(:frozen, frozen) - """ - params: dict[str, Any] = { - "server": server, - "set_id": set_id, - "status": status, - "artist": artist, - "creator": creator, - "filename": filename, - "mode": mode, - "frozen": frozen, - } + + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM maps WHERE 1 = 1"), + AND(server, equals_variable("server", "server")), + AND(set_id, equals_variable("set_id", "set_id")), + AND(status, equals_variable("status", "status")), + AND(artist, equals_variable("artist", "artist")), + AND(creator, equals_variable("creator", "creator")), + AND(filename, equals_variable("filename", "filename")), + AND(mode, equals_variable("mode", "mode")), + AND(frozen, equals_variable("frozen", "frozen")), + ) if page is not None and page_size is not None: query += """\ @@ -359,11 +334,12 @@ async def update( if not isinstance(diff, _UnsetSentinel): update_fields["diff"] = diff - query = f"""\ - UPDATE maps - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ + query, _ = bq( + sql("UPDATE maps SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("id", "id")), + ) + values = {"id": id} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/players.py b/app/repositories/players.py index cd4f1a687..0ee9b4609 100644 --- a/app/repositories/players.py +++ b/app/repositories/players.py @@ -9,6 +9,7 @@ from app._typing import _UnsetSentinel from app._typing import UNSET from app.utils import make_safe_name +from app.query_builder import build as bq, sql, equals_variable, AND, WHERE # +-------------------+---------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | @@ -128,18 +129,17 @@ async def fetch_one( if id is None and name is None and email is None: raise ValueError("Must provide at least one parameter.") - query = f"""\ - SELECT {'*' if fetch_all_fields else READ_PARAMS} - FROM users - WHERE id = COALESCE(:id, id) - AND safe_name = COALESCE(:safe_name, safe_name) - AND email = COALESCE(:email, email) - """ - params: dict[str, Any] = { - "id": id, - "safe_name": make_safe_name(name) if name is not None else None, - "email": email, - } + safe_name = make_safe_name(name) if name is not None else None + + query, params = bq( + sql( + f"SELECT {'*' if fetch_all_fields else READ_PARAMS} FROM users WHERE 1 = 1" + ), + AND(id, equals_variable("id", "id")), + AND(safe_name, equals_variable("safe_name", "safe_name")), + AND(email, equals_variable("email", "email")), + ) + player = await app.state.services.database.fetch_one(query, params) return cast(Player, dict(player._mapping)) if player is not None else None @@ -153,24 +153,16 @@ async def fetch_count( play_style: int | None = None, ) -> int: """Fetch the number of players in the database.""" - query = """\ - SELECT COUNT(*) AS count - FROM users - WHERE priv = COALESCE(:priv, priv) - AND country = COALESCE(:country, country) - AND clan_id = COALESCE(:clan_id, clan_id) - AND clan_priv = COALESCE(:clan_priv, clan_priv) - AND preferred_mode = COALESCE(:preferred_mode, preferred_mode) - AND play_style = COALESCE(:play_style, play_style) - """ - params: dict[str, Any] = { - "priv": priv, - "country": country, - "clan_id": clan_id, - "clan_priv": clan_priv, - "preferred_mode": preferred_mode, - "play_style": play_style, - } + query, params = bq( + sql("SELECT COUNT(*) AS count FROM users WHERE 1 = 1"), + AND(priv, equals_variable("priv", "priv")), + AND(country, equals_variable("country", "country")), + AND(clan_id, equals_variable("clan_id", "clan_id")), + AND(clan_priv, equals_variable("clan_priv", "clan_priv")), + AND(preferred_mode, equals_variable("preferred_mode", "preferred_mode")), + AND(play_style, equals_variable("play_style", "play_style")), + ) + rec = await app.state.services.database.fetch_one(query, params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -187,32 +179,22 @@ async def fetch_many( page_size: int | None = None, ) -> list[Player]: """Fetch multiple players from the database.""" - query = f"""\ - SELECT {READ_PARAMS} - FROM users - WHERE priv = COALESCE(:priv, priv) - AND country = COALESCE(:country, country) - AND clan_id = COALESCE(:clan_id, clan_id) - AND clan_priv = COALESCE(:clan_priv, clan_priv) - AND preferred_mode = COALESCE(:preferred_mode, preferred_mode) - AND play_style = COALESCE(:play_style, play_style) - """ - params: dict[str, Any] = { - "priv": priv, - "country": country, - "clan_id": clan_id, - "clan_priv": clan_priv, - "preferred_mode": preferred_mode, - "play_style": play_style, - } - - if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM users WHERE 1 = 1"), + AND(priv, equals_variable("priv", "priv")), + AND(country, equals_variable("country", "country")), + AND(clan_id, equals_variable("clan_id", "clan_id")), + AND(clan_priv, equals_variable("clan_priv", "clan_priv")), + AND(preferred_mode, equals_variable("preferred_mode", "preferred_mode")), + AND(play_style, equals_variable("play_style", "play_style")), + ( + (page_size, "LIMIT :page_size"), + lambda: ( + (page - 1) * page_size if page is not None else None, + "OFFSET :offset", + ), + ), + ) players = await app.state.services.database.fetch_all(query, params) return cast(list[Player], [dict(p._mapping) for p in players]) @@ -232,10 +214,10 @@ async def update( clan_priv: int | _UnsetSentinel = UNSET, preferred_mode: int | _UnsetSentinel = UNSET, play_style: int | _UnsetSentinel = UNSET, - custom_badge_name: str | None | _UnsetSentinel = UNSET, - custom_badge_icon: str | None | _UnsetSentinel = UNSET, - userpage_content: str | None | _UnsetSentinel = UNSET, - api_key: str | None | _UnsetSentinel = UNSET, + custom_badge_name: str | _UnsetSentinel = UNSET, + custom_badge_icon: str | _UnsetSentinel = UNSET, + userpage_content: str | _UnsetSentinel = UNSET, + api_key: str | _UnsetSentinel = UNSET, ) -> Player | None: """Update a player in the database.""" update_fields: PlayerUpdateFields = {} @@ -273,11 +255,12 @@ async def update( if not isinstance(api_key, _UnsetSentinel): update_fields["api_key"] = api_key - query = f"""\ - UPDATE users - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ + query, _ = bq( + sql("UPDATE users SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("id", "id")), + ) + values = {"id": id} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/scores.py b/app/repositories/scores.py index dc440c4a7..43b9a2f44 100644 --- a/app/repositories/scores.py +++ b/app/repositories/scores.py @@ -9,6 +9,7 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import build as bq, sql, equals_variable, AND, WHERE # +-----------------+-----------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | @@ -183,22 +184,14 @@ async def fetch_count( mode: int | None = None, user_id: int | None = None, ) -> int: - query = """\ - SELECT COUNT(*) AS count - FROM scores - WHERE map_md5 = COALESCE(:map_md5, map_md5) - AND mods = COALESCE(:mods, mods) - AND status = COALESCE(:status, status) - AND mode = COALESCE(:mode, mode) - AND userid = COALESCE(:userid, userid) - """ - params: dict[str, Any] = { - "map_md5": map_md5, - "mods": mods, - "status": status, - "mode": mode, - "userid": user_id, - } + query, params = bq( + sql("SELECT COUNT(*) count FROM scores WHERE 1 = 1"), + AND(map_md5, equals_variable("map_md5", "map_md5")), + AND(mods, equals_variable("mods", "mods")), + AND(status, equals_variable("status", "status")), + AND(mode, equals_variable("mode", "mode")), + AND(user_id, equals_variable("userid", "userid")), + ) rec = await app.state.services.database.fetch_one(query, params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -213,29 +206,21 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Score]: - query = f"""\ - SELECT {READ_PARAMS} - FROM scores - WHERE map_md5 = COALESCE(:map_md5, map_md5) - AND mods = COALESCE(:mods, mods) - AND status = COALESCE(:status, status) - AND mode = COALESCE(:mode, mode) - AND userid = COALESCE(:userid, userid) - """ - params: dict[str, Any] = { - "map_md5": map_md5, - "mods": mods, - "status": status, - "mode": mode, - "userid": user_id, - } - if page is not None and page_size is not None: - query += """\ - LIMIT :page_size - OFFSET :offset - """ - params["page_size"] = page_size - params["offset"] = (page - 1) * page_size + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM scores WHERE 1 = 1"), + AND(map_md5, equals_variable("map_md5", "map_md5")), + AND(mods, equals_variable("mods", "mods")), + AND(status, equals_variable("status", "status")), + AND(mode, equals_variable("mode", "mode")), + AND(user_id, equals_variable("userid", "userid")), + ( + (page_size, sql("LIMIT :page_size")), + lambda: ( + (page - 1) * page_size if page is not None else None, + sql("OFFSET :offset"), + ), + ), + ) recs = await app.state.services.database.fetch_all(query, params) return cast(list[Score], [dict(r._mapping) for r in recs]) @@ -253,11 +238,12 @@ async def update( if not isinstance(status, _UnsetSentinel): update_fields["status"] = status - query = f"""\ - UPDATE scores - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - """ + query, _ = bq( + sql("UPDATE scores SET"), + sql(",".join(f"{k} = :{k}" for k in update_fields)), + WHERE(equals_variable("id", "id")), + ) + values = {"id": id} | update_fields await app.state.services.database.execute(query, values) diff --git a/app/repositories/stats.py b/app/repositories/stats.py index 6dc623ff3..49fe174fd 100644 --- a/app/repositories/stats.py +++ b/app/repositories/stats.py @@ -8,6 +8,16 @@ import app.state.services from app._typing import _UnsetSentinel from app._typing import UNSET +from app.query_builder import ( + build as bq, + sql, + equals_variable, + AND, + UPDATE, + SET, + WHERE, + table, +) # +--------------+-----------------+------+-----+---------+----------------+ # | Field | Type | Null | Key | Default | Extra | @@ -158,16 +168,12 @@ async def fetch_count( player_id: int | None = None, mode: int | None = None, ) -> int: - query = """\ - SELECT COUNT(*) AS count - FROM stats - WHERE id = COALESCE(:id, id) - AND mode = COALESCE(:mode, mode) - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } + query, params = bq( + sql("SELECT COUNT(*) count FROM stats WHERE 1 = 1"), + AND(player_id, equals_variable("id", "id")), + AND(mode, equals_variable("mode", "mode")), + ) + rec = await app.state.services.database.fetch_one(query, params) assert rec is not None return cast(int, rec._mapping["count"]) @@ -179,24 +185,18 @@ async def fetch_many( page: int | None = None, page_size: int | None = None, ) -> list[Stat]: - query = f"""\ - SELECT {READ_PARAMS} - FROM stats - WHERE id = COALESCE(:id, id) - AND mode = COALESCE(:mode, mode) - """ - params: dict[str, Any] = { - "id": player_id, - "mode": mode, - } - - if page is not None and page_size is not None: - query += """\ - LIMIT :limit - OFFSET :offset - """ - params["limit"] = page_size - params["offset"] = (page - 1) * page_size + query, params = bq( + sql(f"SELECT {READ_PARAMS} FROM stats WHERE 1 = 1"), + AND(player_id, equals_variable("id", "id")), + AND(mode, equals_variable("mode", "mode")), + ( + (page_size, "LIMIT :page_size"), + lambda: ( + (page - 1) * page_size if page is not None else None, + "OFFSET :offset", + ), + ), + ) stats = await app.state.services.database.fetch_all(query, params) return cast(list[Stat], [dict(s._mapping) for s in stats]) @@ -251,12 +251,16 @@ async def update( if not isinstance(a_count, _UnsetSentinel): update_fields["a_count"] = a_count - query = f"""\ - UPDATE stats - SET {",".join(f"{k} = COALESCE(:{k}, {k})" for k in update_fields)} - WHERE id = :id - AND mode = :mode - """ + query, _ = bq( + UPDATE( + table("stats"), + SET(*(equals_variable(k, k) for k in update_fields)), + WHERE( + equals_variable("id", "id"), + AND(equals_variable("mode", "mode")), + ), + ) + ) values = {"id": player_id, "mode": mode} | update_fields await app.state.services.database.execute(query, values)