From 3e654f7bb784a62dafab66cff13662c761f2de93 Mon Sep 17 00:00:00 2001 From: nova Date: Mon, 19 Feb 2024 09:32:03 +0000 Subject: [PATCH 1/2] Fix typing --- src/config.py | 2 +- src/extensions/boosts.py | 10 +++------- src/extensions/hello_world.py | 14 +++----------- src/extensions/uptime.py | 8 ++++---- src/extensions/user_roles.py | 34 +++++++++++----------------------- src/utils.py | 6 ++++-- 6 files changed, 26 insertions(+), 48 deletions(-) diff --git a/src/config.py b/src/config.py index a4ef636..42cc71a 100644 --- a/src/config.py +++ b/src/config.py @@ -7,4 +7,4 @@ TOKEN = os.environ.get("TOKEN") # required DEBUG = os.environ.get("DEBUG", False) -CHANNEL_IDS = {"lobby": "627542044390457350"} +CHANNEL_IDS: dict[str, int] = {"lobby": 627542044390457350} diff --git a/src/extensions/boosts.py b/src/extensions/boosts.py index 5f8593a..5fc6c7d 100644 --- a/src/extensions/boosts.py +++ b/src/extensions/boosts.py @@ -12,9 +12,7 @@ hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION_TIER_3, ] -BOOST_MESSAGE_TYPES: list[hikari.MessageType] = BOOST_TIERS + [ - hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION, -] +BOOST_MESSAGE_TYPES: list[hikari.MessageType] = BOOST_TIERS + [hikari.MessageType.USER_PREMIUM_GUILD_SUBSCRIPTION] def build_boost_message( @@ -26,9 +24,7 @@ def build_boost_message( assert message_type in BOOST_MESSAGE_TYPES base_message = f"{booster_user.display_name} just boosted the server" - multiple_boosts_message = ( - f" **{number_of_boosts}** times" if number_of_boosts else "" - ) + multiple_boosts_message = f" **{number_of_boosts}** times" if number_of_boosts else "" message = base_message + multiple_boosts_message + "!" @@ -40,7 +36,7 @@ def build_boost_message( @plugin.listen() -async def on_message(event: hikari.GuildMessageCreateEvent): +async def on_message(event: hikari.GuildMessageCreateEvent) -> None: if event.message.type not in BOOST_MESSAGE_TYPES: return diff --git a/src/extensions/hello_world.py b/src/extensions/hello_world.py index 1f018cc..be97d92 100644 --- a/src/extensions/hello_world.py +++ b/src/extensions/hello_world.py @@ -31,16 +31,10 @@ async def options( ctx: arc.GatewayContext, option_str: arc.Option[str, arc.StrParams("A string option")], option_int: arc.Option[int, arc.IntParams("An integer option")], - option_attachment: arc.Option[ - hikari.Attachment, arc.AttachmentParams("An attachment option") - ], + option_attachment: arc.Option[hikari.Attachment, arc.AttachmentParams("An attachment option")], ) -> None: """A command with lots of options""" - embed = hikari.Embed( - title="There are a lot of options here", - description="Maybe too many", - colour=0x5865F2, - ) + embed = hikari.Embed(title="There are a lot of options here", description="Maybe too many", colour=0x5865F2) embed.set_image(option_attachment) embed.add_field("String option", option_str, inline=False) embed.add_field("Integer option", str(option_int), inline=False) @@ -52,9 +46,7 @@ async def options( async def components(ctx: arc.GatewayContext) -> None: """A command with components""" builder = ctx.client.rest.build_message_action_row() - select_menu = builder.add_text_menu( - "select_me", placeholder="I wonder what this does", min_values=1, max_values=2 - ) + select_menu = builder.add_text_menu("select_me", placeholder="I wonder what this does", min_values=1, max_values=2) for opt in ("Select me!", "No, select me!", "Select me too!"): select_menu.add_option(opt, opt) diff --git a/src/extensions/uptime.py b/src/extensions/uptime.py index 2b1cf89..d477b6c 100644 --- a/src/extensions/uptime.py +++ b/src/extensions/uptime.py @@ -9,14 +9,14 @@ @plugin.include @arc.slash_command("uptime", "Show formatted uptime of Blockbot") -async def uptime(ctx): +async def uptime(ctx: arc.GatewayContext) -> None: up_time = datetime.now() - start_time d = up_time.days h, ms = divmod(up_time.seconds, 3600) m, s = divmod(ms, 60) - def format(val, str): - return f"{val} {str}{'s' if val != 1 else ''}" + def format(val: int, type_: str): + return f"{val} {type_}{'s' if val != 1 else ''}" message_parts = [(d, "day"), (h, "hour"), (m, "minute"), (s, "second")] formatted_parts = [format(val, str) for val, str in message_parts if val] @@ -25,5 +25,5 @@ def format(val, str): @arc.loader -def loader(client): +def loader(client: arc.GatewayClient) -> None: client.add_plugin(plugin) diff --git a/src/extensions/user_roles.py b/src/extensions/user_roles.py index 63892c2..14e9c09 100644 --- a/src/extensions/user_roles.py +++ b/src/extensions/user_roles.py @@ -24,18 +24,11 @@ async def add_role( assert ctx.member if int(role) in ctx.member.role_ids: - return await ctx.respond( - f"You already have the {role_mention(role)} role.", - flags=hikari.MessageFlag.EPHEMERAL, - ) + await ctx.respond(f"You already have the {role_mention(role)} role.", flags=hikari.MessageFlag.EPHEMERAL) + return - await ctx.client.rest.add_role_to_member( - ctx.guild_id, ctx.author, int(role), reason="Self-service role." - ) - await ctx.respond( - f"Done! Added {role_mention(role)} to your roles.", - flags=hikari.MessageFlag.EPHEMERAL, - ) + await ctx.client.rest.add_role_to_member(ctx.guild_id, ctx.author, int(role), reason="Self-service role.") + await ctx.respond(f"Done! Added {role_mention(role)} to your roles.", flags=hikari.MessageFlag.EPHEMERAL) @role.include @@ -48,18 +41,13 @@ async def remove_role( assert ctx.member if int(role) not in ctx.member.role_ids: - return await ctx.respond( - f"You don't have the {role_mention(role)} role.", - flags=hikari.MessageFlag.EPHEMERAL, - ) + await ctx.respond(f"You don't have the {role_mention(role)} role.", flags=hikari.MessageFlag.EPHEMERAL) + return await ctx.client.rest.remove_role_from_member( ctx.guild_id, ctx.author, int(role), reason=f"{ctx.author} removed role." ) - await ctx.respond( - f"Done! Removed {role_mention(role)} from your roles.", - flags=hikari.MessageFlag.EPHEMERAL, - ) + await ctx.respond(f"Done! Removed {role_mention(role)} from your roles.", flags=hikari.MessageFlag.EPHEMERAL) @role.set_error_handler @@ -68,15 +56,15 @@ async def role_error_handler(ctx: arc.GatewayContext, exc: Exception) -> None: assert role is not None if isinstance(exc, hikari.ForbiddenError): - return await ctx.respond( + await ctx.respond( f"❌ Blockbot is not permitted to self-service the {role_mention(role)} role.", flags=hikari.MessageFlag.EPHEMERAL, ) + return if isinstance(exc, hikari.NotFoundError): - return await ctx.respond( - "❌ Blockbot can't find that role.", flags=hikari.MessageFlag.EPHEMERAL - ) + await ctx.respond("❌ Blockbot can't find that role.", flags=hikari.MessageFlag.EPHEMERAL) + return raise exc diff --git a/src/utils.py b/src/utils.py index f4a1892..20762ce 100644 --- a/src/utils.py +++ b/src/utils.py @@ -2,9 +2,11 @@ from arc import GatewayClient -async def get_guild(client: GatewayClient, event: hikari.GuildMessageCreateEvent): +async def get_guild( + client: GatewayClient, event: hikari.GuildMessageCreateEvent +) -> hikari.GatewayGuild | hikari.RESTGuild: return event.get_guild() or await client.rest.fetch_guild(event.guild_id) -def role_mention(role_id: hikari.Snowflake | int | str): +def role_mention(role_id: hikari.Snowflake | int | str) -> str: return f"<@&{role_id}>" From 23f75c766239c8487f6bb8d6af2360ea8f1cea85 Mon Sep 17 00:00:00 2001 From: nova Date: Thu, 22 Feb 2024 10:52:17 +0000 Subject: [PATCH 2/2] Add starboard base --- .dockerignore | 5 ++ Dockerfile | 2 +- docker-compose.yaml | 24 ++++++ requirements.txt | 10 ++- src/bot.py | 16 ++-- src/config.py | 15 +++- src/database.py | 34 ++++++++ src/extensions/starboard.py | 164 ++++++++++++++++++++++++++++++++++++ 8 files changed, 259 insertions(+), 11 deletions(-) create mode 100644 .dockerignore create mode 100644 docker-compose.yaml create mode 100644 src/database.py create mode 100644 src/extensions/starboard.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..c2f97b7 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,5 @@ +.github/ +.ruff_cache/ +.venv/ +postgres_data/ +__pycache__/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 10415ec..f238901 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10.13-alpine3.18 +FROM python:3.12.2-alpine3.19 WORKDIR /app diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..8a69253 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,24 @@ +version: "2" + +services: + bot: + build: . + depends_on: + - postgres + environment: + - POSTGRES_HOST=postgres + - POSTGRES_PORT=5432 + - POSTGRES_DB_NAME=blockbot + restart: unless-stopped + + postgres: + image: postgres:16.2-alpine3.19 + environment: + POSTGRES_DB: blockbot + PGDATA: /var/lib/postgresql/data + restart: unless-stopped + volumes: + - ./postgres_data:/var/lib/postgresql/data + +volumes: + postgres_data: \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index dc1dbf7..a2fce9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ +asyncpg==0.29.0 hikari==2.0.0.dev122 -hikari-arc==1.1.0 -ruff==0.2.0 -pre-commit==3.6.0 -python-dotenv==1.0.1 \ No newline at end of file +hikari-arc==1.2.1 +pre-commit==3.6.2 +python-dotenv==1.0.1 +ruff==0.2.2 +SQLAlchemy==2.0.27 \ No newline at end of file diff --git a/src/bot.py b/src/bot.py index e078fcf..8659b56 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,14 +1,11 @@ import logging -import sys import arc import hikari +from sqlalchemy.ext.asyncio import AsyncEngine from src.config import DEBUG, TOKEN - -if TOKEN is None: - print("TOKEN environment variable not set. Exiting.") - sys.exit(1) +from src.database import Base, engine bot = hikari.GatewayBot( token=TOKEN, @@ -22,7 +19,16 @@ client = arc.GatewayClient(bot, is_dm_enabled=False) client.load_extensions_from("./src/extensions/") +@client.set_startup_hook +async def startup_hook(client: arc.GatewayClient) -> None: + client.set_type_dependency(AsyncEngine, engine) + + async with engine.begin() as conn: + # await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) +# TODO: fix bug where commands that error are respond to twice +# (Once by our message, once by built-in response) @client.set_error_handler async def error_handler(ctx: arc.GatewayContext, exc: Exception) -> None: if DEBUG: diff --git a/src/config.py b/src/config.py index 42cc71a..14e7459 100644 --- a/src/config.py +++ b/src/config.py @@ -1,10 +1,23 @@ import os +import sys from dotenv import load_dotenv load_dotenv() -TOKEN = os.environ.get("TOKEN") # required +def get_required_var(var: str) -> str: + env = os.environ.get(var) + if env is None: + print(f"{var} environment variable not set. Exiting.") + sys.exit(1) + return env + +TOKEN = get_required_var("TOKEN") DEBUG = os.environ.get("DEBUG", False) +POSTGRES_USER = get_required_var("POSTGRES_USER") +POSTGRES_PASSWORD = get_required_var("POSTGRES_PASSWORD") +POSTGRES_HOST = get_required_var("POSTGRES_HOST") +POSTGRES_PORT = get_required_var("POSTGRES_PORT") +POSTGRES_DB_NAME = get_required_var("POSTGRES_DB_NAME") CHANNEL_IDS: dict[str, int] = {"lobby": 627542044390457350} diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..31e8426 --- /dev/null +++ b/src/database.py @@ -0,0 +1,34 @@ +from sqlalchemy import BigInteger, Column, Integer, SmallInteger +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import declarative_base + +from src.config import POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_PORT, POSTGRES_USER, POSTGRES_DB_NAME + + +Base = declarative_base() + +# TODO: add reprs? + +class StarboardSettings(Base): + __tablename__ = "starboard_settings" + + guild = Column(BigInteger, nullable=False, primary_key=True) + channel = Column(BigInteger, nullable=True) + threshold = Column(SmallInteger, nullable=False, default=3) + + +class Starboard(Base): + __tablename__ = "starboard" + + id = Column(Integer, nullable=False, primary_key=True, autoincrement=True) + channel = Column(BigInteger, nullable=False) + message = Column(BigInteger, nullable=False) + stars = Column(SmallInteger, nullable=False) + starboard_channel = Column(BigInteger, nullable=False) + starboard_message = Column(BigInteger, nullable=False) + starboard_stars = Column(SmallInteger, nullable=False) + + +engine = create_async_engine( + f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB_NAME}" +) diff --git a/src/extensions/starboard.py b/src/extensions/starboard.py new file mode 100644 index 0000000..5551845 --- /dev/null +++ b/src/extensions/starboard.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import logging + +import arc +import hikari +from sqlalchemy import delete, insert, select, update +from sqlalchemy.ext.asyncio import AsyncEngine + +from src.database import Starboard, StarboardSettings + +logger = logging.getLogger(__name__) + +plugin = arc.GatewayPlugin("Starboard") + +@plugin.listen() +@plugin.inject_dependencies +async def on_reaction( + event: hikari.GuildReactionAddEvent, + session: AsyncEngine = arc.inject(), +) -> None: + logger.info("Received guild reaction add event") + + if event.emoji_name != "⭐": + return + + message = await plugin.client.rest.fetch_message(event.channel_id, event.message_id) + star_count = sum(r.emoji == "⭐" for r in message.reactions) + + stmt = select(StarboardSettings).where(StarboardSettings.guild == event.guild_id) + async with session.connect() as conn: + result = await conn.execute(stmt) + + settings = result.first() + + # TODO: remove temporary logging and merge into one if statement + if not settings: + logger.info("Received star but no guild starboard set") + return + if star_count < settings.threshold: + logger.info("Not enough stars to post to starboard") + return + if not settings.channel: + logger.info("No starboard channel set") + return + + async with session.connect() as conn: + stmt = select(Starboard).where(Starboard.message == event.message_id) + result = await conn.execute(stmt) + starboard = result.first() + + logger.info(starboard) + + if not starboard: + stmt = select(Starboard).where(Starboard.starboard_message == event.message_id) + result = await conn.execute(stmt) + starboard = result.first() + + logger.info(starboard) + + embed = hikari.Embed(description=f"⭐ {star_count}\n[link]({message.make_link(event.guild_id)})") + + # TODO: handle starring the starboard message + # i.e. don't create a starboard message for the starboard message + + if not starboard: + try: + logger.info("Creating message") + message = await plugin.client.rest.create_message( + settings.channel, + embed, + ) + stmt = insert(Starboard).values( + channel=event.channel_id, + message=event.message_id, + stars=star_count, + starboard_channel=settings.channel, + starboard_message=message.id, + starboard_stars=0, + ) + + async with session.begin() as conn: + await conn.execute(stmt) + await conn.commit() + except hikari.ForbiddenError: + logger.info("Can't access starboard channel") + stmt = update(StarboardSettings).where(StarboardSettings.guild == event.guild_id).values( + channel=None) + + async with session.begin() as conn: + await conn.execute(stmt) + await conn.commit() + + else: + try: + logger.info("Editing message") + await plugin.client.rest.edit_message( + starboard.starboard_channel, + starboard.starboard_message, + embed + ) + except hikari.ForbiddenError: + logger.info("Can't edit starboard message") + stmt = delete(StarboardSettings).where(StarboardSettings.guild == event.guild_id) + + async with session.begin() as conn: + await conn.execute(stmt) + await conn.commit() + +@plugin.include +@arc.slash_command("starboard", "Edit or view starboard settings.", default_permissions=hikari.Permissions.MANAGE_GUILD) +async def starboard_settings( + ctx: arc.GatewayContext, + channel: arc.Option[hikari.TextableGuildChannel | None, arc.ChannelParams("The channel to post starboard messages to.")] = None, + threshold: arc.Option[int | None, arc.IntParams("The minimum number of stars before this message is posted to the starboard", min=1)] = None, + session: AsyncEngine = arc.inject(), +) -> None: + assert ctx.guild_id + + stmt = select(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id) + async with session.connect() as conn: + result = await conn.execute(stmt) + + settings = result.first() + + if not channel and not threshold: + if not settings: + await ctx.respond("This server has no starboard settings.", flags=hikari.MessageFlag.EPHEMERAL) + else: + # TODO: `channel` and `threshold` can be None + embed = hikari.Embed( + title="Starboard Settings", + description=( + f"**Channel:** <#{settings.channel}>\n" + f"**Threshold:** {settings.threshold}" + ), + ) + await ctx.respond(embed) + + return + + if not settings: + stmt = insert(StarboardSettings).values(guild=ctx.guild_id) + else: + stmt = update(StarboardSettings).where(StarboardSettings.guild == ctx.guild_id) + + # TODO: simplify logic + if channel and threshold: + stmt = stmt.values(channel=channel.id, threshold=threshold) + elif channel: + stmt = stmt.values(channel=channel.id) + elif threshold: + stmt = stmt.values(threshold=threshold) + + async with session.begin() as conn: + await conn.execute(stmt) + await conn.commit() + + # TODO: respond with embed of new settings? + await ctx.respond("Settings updated.", flags=hikari.MessageFlag.EPHEMERAL) + +@arc.loader +def loader(client: arc.GatewayClient) -> None: + client.add_plugin(plugin)