From 68b077a7769b95b09b5bc1f72a6adb9471c8f603 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:02:15 +0930 Subject: [PATCH 01/22] Refactored config into separate file. --- config.example.json | 12 +++++++----- utils/config.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 utils/config.py diff --git a/config.example.json b/config.example.json index 6563df7..d031dca 100644 --- a/config.example.json +++ b/config.example.json @@ -1,7 +1,9 @@ { - "token": "", - "guild": 0, - "channel": 0, - "prefix": "", - "status": "" + "token": "", + "guild": 0, + "channel": 0, + "prefix": "", + "status": "", + "application_id": 0, + "id_prefix": "" } diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..619effc --- /dev/null +++ b/utils/config.py @@ -0,0 +1,23 @@ +import json +import os +from pathlib import Path + +_path = Path(__file__).parent / "../config.json" +_config = json.load(open(_path, 'r')) + +token = os.getenv( + "MODMAIL_TOKEN") if "MODMAIL_TOKEN" in os.environ else _config["token"] +application_id = int( + os.getenv("MODMAIL_APPLICATION_ID") +) if "MODMAIL_APPLICATION_ID" in os.environ else _config["application_id"] +guild = int(os.getenv( + "MODMAIL_GUILD")) if "MODMAIL_GUILD" in os.environ else _config["guild"] +channel = int(os.getenv("MODMAIL_CHANNEL") + ) if "MODMAIL_CHANNEL" in os.environ else _config["channel"] +prefix = os.getenv( + "MODMAIL_PREFIX") if "MODMAIL_PREFIX" in os.environ else _config["prefix"] +status = os.getenv( + "MODMAIL_STATUS") if "MODMAIL_STATUS" in os.environ else _config["status"] +id_prefix = os.getenv( + "MODMAIL_ID_PREFIX" +) if "MODMAIL_ID_PREFIX" in os.environ else _config["id_prefix"] From a258170b0193d01f04dfc114c6476cecf5bffa22 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:02:35 +0930 Subject: [PATCH 02/22] Refactored db functions to use async sqlite library. --- db.py | 319 ++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 185 insertions(+), 134 deletions(-) diff --git a/db.py b/db.py index cdc4fda..8fbca10 100644 --- a/db.py +++ b/db.py @@ -1,187 +1,238 @@ -import sqlite3 - -def database(func): - def wrapper(*args,**kwargs): - conn = sqlite3.connect('/database/modmail.db') - conn.row_factory = sqlite3.Row - cursor = conn.cursor() - ret = func(cursor, *args, *kwargs) - conn.commit() - conn.close() - return ret - return wrapper - -@database -def get_ticket(cursor, ticket_id): +from contextlib import asynccontextmanager +import aiosqlite +from typing import Optional + +# FIX: Change this to `/database/modmail.db` before push +path = 'modmail.db' + + +@asynccontextmanager +async def db_ops(): + conn = await aiosqlite.connect(path) + conn.row_factory = aiosqlite.Row + cursor = await conn.cursor() + yield cursor + await conn.commit() + await conn.close() + + +class Ticket: + + def __init__(self, ticket_id: int, user: int, open: int, + message_id: Optional[int]): + self.ticket_id = ticket_id + self.user = user + self.open = open + self.message_id = message_id + + def __repr__(self) -> str: + return f"Ticket({self.ticket_id}, {self.user}, {self.open}, {self.message_id})" + + +class TicketResponse: + + def __init__(self, user: int, response: str, timestamp: int, + as_server: bool): + self.user = user + self.response = response + self.timestamp = timestamp + self.as_server = as_server + + def __repr__(self) -> str: + return f"TicketResponse({self.user}, {self.response}, {self.timestamp}, {self.as_server})" + + +class Timeout: + + def __init__(self, timeout_id: int, timestamp: int): + self.timeout_id = timeout_id + self.timestamp = timestamp + + def __repr__(self) -> str: + return f"Timeout({self.timeout_id}, {self.user}, {self.timestamp})" + + +async def get_ticket(ticket_id: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE ticket_id=? """ - cursor.execute(sql, [ticket_id]) - ticket = cursor.fetchone() - if ticket is None or len(ticket) == 0: - return -1 - else: - return ticket - -@database -def get_ticket_by_user(cursor, user): + async with db_ops() as cursor: + await cursor.execute(sql, [ticket_id]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) + + +async def get_ticket_by_user(user: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE user=? AND open=1 """ - cursor.execute(sql, [user]) - ticket = cursor.fetchone() - if ticket is None or len(ticket) == 0: - return {'ticket_id': -1, 'user': -1, 'open':0, 'message_id':-1} - else: - return ticket - -@database -def get_ticket_by_message(cursor, message_id): + async with db_ops() as cursor: + await cursor.execute(sql, [user]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) + + +async def get_ticket_by_message(message_id: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE message_id=? """ - cursor.execute(sql, [message_id]) - ticket = cursor.fetchone() - if ticket is None or len(ticket) == 0: - return {'ticket_id': -1, 'user': -1, 'open':0, 'message_id':-1} - else: - return ticket - -@database -def open_ticket(cursor, user): + async with db_ops() as cursor: + await cursor.execute(sql, [message_id]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) + + +async def open_ticket(user: int) -> Optional[int]: sql = """ INSERT INTO mm_tickets (user) VALUES (?) """ - cursor.execute(sql, [user]) - return cursor.lastrowid + async with db_ops() as cursor: + await cursor.execute(sql, [user]) + return cursor.lastrowid -@database -def update_ticket_message(cursor, ticket_id, message_id): + +async def update_ticket_message(ticket_id: int, message_id: int) -> bool: sql = """ UPDATE mm_tickets SET message_id=? WHERE ticket_id=? """ - cursor.execute(sql, [message_id, ticket_id]) - return cursor.rowcount != 0 + async with db_ops() as cursor: + await cursor.execute(sql, [message_id, ticket_id]) + return cursor.rowcount != 0 + -@database -def close_ticket(cursor, ticket_id): +async def close_ticket(ticket_id: int) -> bool: sql = """ UPDATE mm_tickets SET open=0 WHERE ticket_id=? """ - cursor.execute(sql,[ticket_id]) + async with db_ops() as cursor: + await cursor.execute(sql, [ticket_id]) + return cursor.rowcount != 0 - return cursor.rowcount != 0 -@database -def get_ticket_responses(cursor, ticket_id): +async def get_ticket_responses(ticket_id: int) -> list[TicketResponse]: sql = """ SELECT user, response, timestamp, as_server FROM mm_ticket_responses WHERE ticket_id=? """ - cursor.execute(sql, [ticket_id]) - return cursor.fetchall() + async with db_ops() as cursor: + await cursor.execute(sql, [ticket_id]) + rows = await cursor.fetchall() + return [TicketResponse(*row) for row in rows] + -@database -def add_ticket_response(cursor, ticket_id, user, response, as_server): +async def add_ticket_response(ticket_id: int, user: int, response: str, + as_server: bool) -> Optional[int]: sql = """ INSERT INTO mm_ticket_responses (ticket_id, user, response, as_server) VALUES (?, ?, ?, ?) """ - cursor.execute(sql, [ticket_id, user, response, as_server]) - return True + async with db_ops() as cursor: + await cursor.execute(sql, [ticket_id, user, response, as_server]) + return cursor.lastrowid -@database -def get_timeout(cursor, user): + +async def get_timeout(user: int) -> Optional[Timeout]: sql = """ - SELECT timestamp + SELECT timeout_id, timestamp FROM mm_timeouts WHERE user=? """ - cursor.execute(sql , [user]) - timeout = cursor.fetchone() - if timeout is None or len(timeout) == 0: - return False - else: - return timeout - -@database -def set_timeout(cursor, user, timestamp): - sql = """ - INSERT OR REPLACE INTO mm_timeouts (user, timestamp) - VALUES (?, ?) - """ - cursor.execute(sql, [user, timestamp]) - return True - -@database -def init(cursor): - - #Create modmail tickets table - sql = """ - CREATE TABLE IF NOT EXISTS mm_tickets ( - ticket_id INTEGER PRIMARY KEY AUTOINCREMENT, - user INTEGER NOT NULL, - open BOOLEAN DEFAULT 1 NOT NULL, - message_id INTEGER - ); - """ - result = cursor.execute(sql) - - #Create modmail ticket user index - sql = "CREATE INDEX IF NOT EXISTS mm_tickets_user ON mm_tickets(user);" - result = cursor.execute(sql) + async with db_ops() as cursor: + await cursor.execute(sql, [user]) + timeout = await cursor.fetchone() + if timeout is None or len(timeout) == 0: + return None + else: + return Timeout(*timeout) - #Create modmail ticket message index - sql = "CREATE INDEX IF NOT EXISTS mm_tickets_message ON mm_tickets(message_id);" - result = cursor.execute(sql) - #Create modmail ticket repsonses table +async def set_timeout(user: int, timestamp: int) -> Optional[int]: sql = """ - CREATE TABLE IF NOT EXISTS mm_ticket_responses ( - response_id INTEGER PRIMARY KEY AUTOINCREMENT, - ticket_id INTEGER, - user INTEGER NOT NULL, - response TEXT NOT NULL, - timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL, - as_server BOOLEAN NOT NULL, - FOREIGN KEY (ticket_id) REFERENCES mm_tickets (ticket_id) - ); - """ - result = cursor.execute(sql) - - #Create modmail ticket response ticket id index - sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_ticket_id ON mm_ticket_responses(ticket_id);" - result = cursor.execute(sql) - - #Create modmail ticket response user index - sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_user ON mm_ticket_responses(user);" - result = cursor.execute(sql) - - #Create modmail timeouts table - sql = """ - CREATE TABLE IF NOT EXISTS mm_timeouts ( - timeout_id INTEGER PRIMARY KEY AUTOINCREMENT, - user INTEGER NOT NULL UNIQUE, - timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL - ); + INSERT OR REPLACE INTO mm_timeouts (user, timestamp) + VALUES (?, ?) """ - result = cursor.execute(sql) - - #Create modmail timeout user index - sql = "CREATE UNIQUE INDEX IF NOT EXISTS mm_timeouts_user ON mm_timeouts(user);" - result = cursor.execute(sql) - - return True \ No newline at end of file + async with db_ops() as cursor: + await cursor.execute(sql, [user, timestamp]) + return cursor.lastrowid + + +async def init(): + async with db_ops() as cursor: + # Create modmail tickets table + sql = """ + CREATE TABLE IF NOT EXISTS mm_tickets ( + ticket_id INTEGER PRIMARY KEY AUTOINCREMENT, + user INTEGER NOT NULL, + open BOOLEAN DEFAULT 1 NOT NULL, + message_id INTEGER + ); + """ + await cursor.execute(sql) + + # Create modmail ticket user index + sql = "CREATE INDEX IF NOT EXISTS mm_tickets_user ON mm_tickets(user);" + await cursor.execute(sql) + + # Create modmail ticket message index + sql = "CREATE INDEX IF NOT EXISTS mm_tickets_message ON mm_tickets(message_id);" + await cursor.execute(sql) + + # Create modmail ticket repsonses table + sql = """ + CREATE TABLE IF NOT EXISTS mm_ticket_responses ( + response_id INTEGER PRIMARY KEY AUTOINCREMENT, + ticket_id INTEGER, + user INTEGER NOT NULL, + response TEXT NOT NULL, + timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL, + as_server BOOLEAN NOT NULL, + FOREIGN KEY (ticket_id) REFERENCES mm_tickets (ticket_id) + ); + """ + await cursor.execute(sql) + + # Create modmail ticket response ticket id index + sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_ticket_id ON mm_ticket_responses(ticket_id);" + await cursor.execute(sql) + + # Create modmail ticket response user index + sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_user ON mm_ticket_responses(user);" + await cursor.execute(sql) + + # Create modmail timeouts table + sql = """ + CREATE TABLE IF NOT EXISTS mm_timeouts ( + timeout_id INTEGER PRIMARY KEY AUTOINCREMENT, + user INTEGER NOT NULL UNIQUE, + timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL + ); + """ + await cursor.execute(sql) + + # Create modmail timeout user index + sql = "CREATE UNIQUE INDEX IF NOT EXISTS mm_timeouts_user ON mm_timeouts(user);" + await cursor.execute(sql) + + return True From 487dfe173a4e34e13d1ca2c95012bea2092ff752 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:03:01 +0930 Subject: [PATCH 03/22] Updated `requirements.txt` and `.gitignore` --- .gitignore | 3 ++- requirements.txt | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index d9ecc4f..b276478 100644 --- a/.gitignore +++ b/.gitignore @@ -102,6 +102,7 @@ celerybeat.pid *.sage.py # Environments +modmail.env .env .venv env/ @@ -129,4 +130,4 @@ dmypy.json .pyre/ config.json -modmail.db \ No newline at end of file +modmail.db diff --git a/requirements.txt b/requirements.txt index 503dba9..196d252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,10 @@ -discord.py \ No newline at end of file +aiohttp==3.8.4 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==23.1.0 +charset-normalizer==3.1.0 +discord.py==2.3.1 +frozenlist==1.3.3 +idna==3.4 +multidict==6.0.4 +yarl==1.9.2 From 9ee6d431bf3f7d3e80dc8ade01af256fe14edc7a Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:03:26 +0930 Subject: [PATCH 04/22] Converted to using slash commands and interactions. --- cogs/command_actions.py | 160 --------------------------- cogs/commands.py | 147 +++++++++++++++++++++++++ cogs/listeners.py | 89 +++++++++------ modmail.py | 205 ++++++++++------------------------- utils/actions.py | 226 +++++++++++++++++++++++++++++++++++++++ utils/embed_reactions.py | 143 ------------------------- utils/ticket_embed.py | 206 +++++++++++++++++++++++++++-------- utils/uformatter.py | 18 ++-- 8 files changed, 663 insertions(+), 531 deletions(-) delete mode 100644 cogs/command_actions.py create mode 100644 cogs/commands.py create mode 100644 utils/actions.py delete mode 100644 utils/embed_reactions.py diff --git a/cogs/command_actions.py b/cogs/command_actions.py deleted file mode 100644 index dd4792b..0000000 --- a/cogs/command_actions.py +++ /dev/null @@ -1,160 +0,0 @@ -import datetime, asyncio, discord -import utils.umember as umember, db, utils.ticket_embed as ticket_embed -from utils.embed_reactions import embed_reactions -from discord.ext import commands - -class CommandActions(commands.Cog): - """Cog to contain command action methods.""" - - def __init__(self, bot, modmail_channel): - """Constructs necessary attributes for all command action methods. - - Args: - bot (commands.Bot): The bot object. - modmail_channel (discord.Channel): The modmail channel specified in config. - """ - - self.bot = bot - self.modmail_channel = modmail_channel - print("\nCog 'Command Actions' loaded") - - async def cog_check(self, ctx: commands.Context): - if ctx.channel != self.modmail_channel: - return False - - if ctx.author == self.bot: - return False - - return True - - @commands.command(name="open") - @commands.guild_only() - async def open_ticket(self, ctx, member): - """Opens ticket for specified user if no tickets are currently open.""" - - member = umember.assert_member(ctx.guild, member) - - ticket = db.get_ticket_by_user(member.id) - - if ticket['ticket_id'] != -1: - await ctx.send('There is already a ticket open for {0}'.format(member)) - return - - ticket_id = db.open_ticket(member.id) - ticket = db.get_ticket(ticket_id) - - ticket_message = await ctx.send(embed=ticket_embed.channel_embed(ctx.guild, ticket['ticket_id'])) - db.update_ticket_message(ticket['ticket_id'], ticket_message.id) - - await post_reactions(ticket_message) - - @commands.command(name="refresh") - @commands.guild_only() - async def refresh_ticket(self, ctx, member): - """Resends embed for specified user if there is a ticket that is already open.""" - - member = umember.assert_member(ctx.guild, member) - - ticket = db.get_ticket_by_user(member.id) - - if ticket['ticket_id'] == -1: - await self.message.channel.send('There is no ticket open for {0}.'.format(member)) - return - - ticket_message = await ctx.send(embed=ticket_embed.channel_embed(ctx.guild, ticket['ticket_id'])) - db.update_ticket_message(ticket['ticket_id'], ticket_message.id) - - if ticket['message_id'] is not None and ticket['message_id'] != -1: - old_ticket_message = await ctx.channel.fetch_message(ticket['message_id']) - await old_ticket_message.delete() - - await post_reactions(ticket_message) - - @commands.command(name="close") - @commands.guild_only() - async def close_ticket(self, ctx, member): - """Closes ticket for specified user given that a ticket is already open.""" - - member = umember.assert_member(ctx.guild, member) - - ticket = db.get_ticket_by_user(member.id) - - if ticket['ticket_id'] == -1: - await ctx.send('There is no ticket open for {0}.'.format(member)) - return - - embed_commands = embed_reactions(self.bot, ctx.guild, self.modmail_channel, ctx.author, ticket) - await embed_commands.message_close() - - @commands.command(name="timeout") - @commands.guild_only() - async def timeout_ticket(self, ctx, member): - """Times out specified user.""" - - member = umember.assert_member(ctx.guild, member) - - embed_commands = embed_reactions(self.bot, ctx.guild, self.modmail_channel, ctx.author) - await embed_commands.message_timeout(member) - - @commands.command(name="untimeout") - @commands.guild_only() - async def untimeout_ticket(self, ctx, member): - """Removes timeout for specified user given that user is currently timed out.""" - - member = umember.assert_member(ctx.guild, member) - - timeout = db.get_timeout(member.id) - current_time = int(datetime.datetime.now().timestamp()) - - if timeout == False or (timeout != False and current_time > timeout['timestamp']): - await ctx.send('{0} is not currently timed out.'.format(member)) - return - - confirmation = await ctx.send(embed=ticket_embed.untimeout_confirmation(member, timeout['timestamp'])) - await confirmation.add_reaction('✅') - await confirmation.add_reaction('❎') - - def untimeout_check(reaction, user): - return user == ctx.author and confirmation == reaction.message and (str(reaction.emoji) == '✅' or str(reaction.emoji) == '❎') - - try: - reaction, user = await self.bot.wait_for('reaction_add', timeout=60.0, check=untimeout_check) - if str(reaction.emoji) == '✅': - timestamp = int(datetime.datetime.now().timestamp()) - db.set_timeout(member.id, timestamp) - await member.send(embed=ticket_embed.user_untimeout()) - await self.modmail_channel.send('{0} has been successfully untimed out.'.format(member)) - except asyncio.TimeoutError: - pass - except discord.errors.Forbidden: - await self.modmail_channel.send("Could not send timeout message to specified user due to privacy settings. Timeout has not been set.") - - await confirmation.delete() - - # ! Fix errors - async def cog_command_error(self, ctx, error): - if type(error) == commands.errors.CheckFailure: - print("Command executed in wrong channel.") - elif type(error) == commands.errors.MissingRequiredArgument: - await ctx.send("A valid user (and one who is still on the server) was not specified.") - else: - await ctx.send(str(error) + "\nIf you do not understand, contact a bot dev.") - - -async def post_reactions(message): - """Adds specified reactions to message. - - Args: - message (discord.Message): The specified message. - """ - - message_reactions = ['🗣️', '❎', '⏲️'] - - for reaction in message_reactions: - try: - await message.add_reaction(reaction) - except discord.errors.NotFound: - pass - -def setup(bot): - bot.add_cog(CommandActions(bot)) \ No newline at end of file diff --git a/cogs/commands.py b/cogs/commands.py new file mode 100644 index 0000000..f7249de --- /dev/null +++ b/cogs/commands.py @@ -0,0 +1,147 @@ +from typing import Literal, Optional +from discord.ext import commands +from discord import app_commands +import discord +import sys + +import db +from utils import config, actions + +import logging + +logger = logging.getLogger(__name__) + + +class Commands(commands.Cog): + """Cog to contain command action methods.""" + + def __init__(self, bot: commands.Bot, + modmail_channel: discord.TextChannel) -> None: + """Constructs necessary attributes for all command action methods. + + Args: + bot (commands.Bot): The bot object. + modmail_channel (discord.TextChannel): The modmail channel specified in config. + """ + + self.bot = bot + self.modmail_channel = modmail_channel + + async def cog_check(self, ctx: commands.Context): + if ctx.channel != self.modmail_channel: + return False + + if ctx.author == self.bot: + return False + + return True + + @commands.command(name="sync") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def sync(self, + ctx: commands.Context, + spec: Optional[Literal["~"]] = None): + if spec == "~": + synced = await ctx.bot.tree.sync() + else: + ctx.bot.tree.copy_global_to(guild=ctx.guild) + synced = await ctx.bot.tree.sync(guild=ctx.guild) + + await ctx.send( + f"Synced {len(synced)} commands {'to the current guild' if spec is None else 'globally'}." + ) + return + + @app_commands.command(name="open") + @commands.guild_only() + async def open_ticket(self, interaction: discord.Interaction, + member: discord.Member): + """Opens ticket for specified user if no tickets are currently open.""" + + # TODO: Check if you can share this across commands + if member.bot: + await interaction.response.send_message( + f'Cannot open ticket for {member.name} (bot).') + return + await actions.message_open(self.bot, interaction, member) + + @app_commands.command(name="refresh") + @commands.guild_only() + async def refresh_ticket(self, interaction: discord.Interaction, + member: discord.Member): + """Resends embed for specified user if there is a ticket that is already open.""" + + if member.bot: + await interaction.response.send_message('Invalid member specified.' + ) + return + await actions.message_refresh(self.bot, interaction, member) + + @app_commands.command(name="close") + @commands.guild_only() + async def close_ticket(self, interaction: discord.Interaction, + member: discord.Member): + """Closes ticket for specified user given that a ticket is already open.""" + + if member.bot: + await interaction.response.send_message('Invalid member specified.' + ) + return + + ticket = await db.get_ticket_by_user(member.id) + + if not ticket: + await interaction.response.send_message( + f'There is no ticket open for {member.name}.') + return + + await actions.message_close(interaction, ticket, member) + + @app_commands.command(name="timeout") + @commands.guild_only() + async def timeout_ticket(self, interaction: discord.Interaction, + member: discord.Member): + """Times out specified user.""" + if member.bot: + await interaction.response.send_message('Invalid member specified.' + ) + return + await actions.message_timeout(interaction, member) + + @app_commands.command(name="untimeout") + @commands.guild_only() + async def untimeout_ticket(self, interaction: discord.Interaction, + member: discord.Member): + """Removes timeout for specified user given that user is currently timed out.""" + + if member.bot: + await interaction.response.send_message('Invalid member specified.' + ) + return + await actions.message_untimeout(interaction, member) + + # FIX: Fix errors + async def cog_command_error(self, ctx, error): + if type(error) == commands.errors.CheckFailure: + print("Command executed in wrong channel.") + elif type(error) == commands.errors.MissingRequiredArgument: + await ctx.send( + "A valid user (and one who is still on the server) was not specified." + ) + else: + await ctx.send( + str(error) + "\nIf you do not understand, contact a bot dev.") + + +async def setup(bot: commands.Bot): + try: + modmail_channel = await bot.fetch_channel(config.channel) + except Exception as e: + logger.error(e) + logger.fatal( + "The channel specified in config was not found. Please check your config." + ) + sys.exit(-1) + + await bot.add_cog(Commands(bot, modmail_channel)) diff --git a/cogs/listeners.py b/cogs/listeners.py index 56c106e..05832b6 100644 --- a/cogs/listeners.py +++ b/cogs/listeners.py @@ -1,30 +1,33 @@ -import datetime from discord.ext import commands import discord +import datetime +import sys + +import db +from utils import config, uformatter, ticket_embed + +import logging + +logger = logging.getLogger(__name__) -import utils.ticket_embed as ticket_embed, db -import utils.uformatter as uformatter -from cogs.command_actions import post_reactions class Listeners(commands.Cog): """Cog to contain all main listener methods.""" - def __init__(self, bot, guild, modmail_channel): + def __init__(self, bot: commands.Bot, + modmail_channel: discord.TextChannel) -> None: """Constructs necessary attributes for all command action methods. Args: bot (commands.Bot): The bot object. - guild (discord.Guild): The specified guild in config. - modmail_channel (discord.Channel): The specified channel in config. + modmail_channel (discord.TextChannel): The specified channel in config. """ self.bot = bot - self.guild = guild self.modmail_channel = modmail_channel - print("\nCog 'Listeners' loaded") - + @commands.Cog.listener() - async def on_message(self, message): + async def on_message(self, message: discord.Message): """Listener for both DM and server messages. Args: @@ -35,53 +38,75 @@ async def on_message(self, message): await self.handle_dm(message) return - async def handle_dm(self, message): + async def handle_dm(self, message: discord.Message): """Handle DM messages. Args: message (discord.Message): The current message. """ - + user = message.author - timeout = db.get_timeout(user.id) + timeout = await db.get_timeout(user.id) current_time = int(datetime.datetime.now().timestamp()) - if timeout != False and current_time < timeout['timestamp']: - await user.send(embed=ticket_embed.user_timeout(timeout['timestamp'])) + if timeout and current_time < timeout.timestamp: + await user.send(embed=ticket_embed.user_timeout(timeout.timestamp)) return - + response = uformatter.format_message(message) if not response.strip(): return - + # ! Fix for longer messages if len(response) > 1000: - await message.channel.send('Your message is too long. Please shorten your message or send in multiple parts.') + await message.channel.send( + 'Your message is too long. Please shorten your message or send in multiple parts.' + ) return - ticket = db.get_ticket_by_user(user.id) + ticket = await db.get_ticket_by_user(user.id) - if ticket['ticket_id'] == -1: - ticket_id = db.open_ticket(user.id) - ticket = db.get_ticket(ticket_id) + if not ticket: + ticket_id = await db.open_ticket(user.id) + ticket = await db.get_ticket(ticket_id) + logger.info(f"Opened new ticket for: {user.id}") try: - if ticket['message_id'] is not None and ticket['message_id'] != -1: - old_ticket_message = await self.modmail_channel.fetch_message(ticket['message_id']) + if ticket and ticket.message_id is not None and ticket.message_id != -1: + # WARNING: Fix handling other channels + # FIX: what if someone deletes the embed + old_ticket_message = await self.modmail_channel.fetch_message( + ticket.message_id) await old_ticket_message.delete() except discord.errors.NotFound: - await message.channel.send('You are being rate limited. Please wait a few seconds before trying again.') + await message.channel.send( + 'You are being rate limited. Please wait a few seconds before trying again.' + ) return - db.add_ticket_response(ticket['ticket_id'], user.id, response, False) + # `ticket` truthiness has been checked prior to the following lines + await db.add_ticket_response(ticket.ticket_id, user.id, response, + False) - ticket_message = await self.modmail_channel.send(embed=ticket_embed.channel_embed(self.guild, ticket['ticket_id'])) + message_embed, buttons_view = await ticket_embed.channel_embed( + self.bot, self.modmail_channel.guild, ticket) + ticket_message = await self.modmail_channel.send(embed=message_embed, + view=buttons_view) await message.add_reaction('📨') - db.update_ticket_message(ticket['ticket_id'], ticket_message.id) - await post_reactions(ticket_message) + await db.update_ticket_message(ticket.ticket_id, ticket_message.id) + + +async def setup(bot: commands.Bot): + try: + modmail_channel = await bot.fetch_channel(config.channel) + except Exception as e: + logger.error(e) + logger.fatal( + "The channel specified in config was not found. Please check your config." + ) + sys.exit(-1) -def setup(bot): - bot.add_cog(Listeners(bot)) \ No newline at end of file + await bot.add_cog(Listeners(bot, modmail_channel)) diff --git a/modmail.py b/modmail.py index cc77bbd..1a36c64 100644 --- a/modmail.py +++ b/modmail.py @@ -1,154 +1,67 @@ -import os - -import discord, json +import discord from discord.ext import commands -import db, utils.embed_reactions as embed_reactions -from cogs.command_actions import CommandActions -from cogs.listeners import Listeners - -with open('./config.json', 'r') as config_json: - config = json.load(config_json) - - # Load from environment variable overrides - if "MODMAIL_TOKEN" in os.environ: - config["token"] = os.getenv("MODMAIL_TOKEN") - if "MODMAIL_GUILD" in os.environ: - config["guild"] = int(os.getenv("MODMAIL_GUILD")) - if "MODMAIL_CHANNEL" in os.environ: - config["channel"] = int(os.getenv("MODMAIL_CHANNEL")) - if "MODMAIL_PREFIX" in os.environ: - config["prefix"] = os.getenv("MODMAIL_PREFIX") - if "MODMAIL_STATUS" in os.environ: - config["status"] = os.getenv("MODMAIL_STATUS") +import db +from utils import config + +import logging + +from utils.ticket_embed import MessageButtonsView + +logger = logging.getLogger('bot') +logger.setLevel(logging.DEBUG) # TODO: Change back to logging.INFO + +logging.basicConfig( + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%d-%b-%y %H:%M:%S') intents = discord.Intents.default() intents.members = True -bot = commands.Bot(intents=intents, command_prefix=config['prefix']) -guild = None -modmail_channel = None - -def bot_ready(func): - def wrapper(*args,**kwargs): - if guild is None or modmail_channel is None: - print('Bot not ready yet') +intents.message_content = True + + +class Modmail(commands.Bot): + + def __init__(self): + super().__init__( + intents=intents, + command_prefix=config.prefix, + description=config.status, + application_id=config.application_id, + ) + + async def setup_hook(self): + if await db.init(): + logger.info('Database sucessfully initialized!') + else: + logger.error('Error while initializing database!') return - ret = func(cursor, *args, *kwargs) - return ret - return wrapper - -@bot_ready -@bot.event -async def on_raw_reaction_add(payload): - """Handles reactions to messages. - - Args: - payload (discord.RawReactionActionEvent): Payload for raw reaction methods. - """ - # Ignore if self - if payload.user_id == bot.user.id: - return - - # Ignore if not in guild - if not payload.guild_id or not payload.guild_id == config['guild']: - return - - # Ignore if not in modmail channel - if not payload.channel_id == config['channel']: - return - - # Ignore if not unicode emoji - if not payload.emoji.is_unicode_emoji(): - return - - # Get member object - reaction_user = payload.member - - # Ignore if bot - if reaction_user.bot: - return - - # Get unicode emoji - emoji = payload.emoji.name - - # Get message object - message = await modmail_channel.fetch_message(payload.message_id) - - await handle_reaction(emoji, message, reaction_user) - -async def handle_reaction(emoji, message, reaction_user): - """Handles reactions for ModMail embeds. - - Args: - emoji (discord.PartialEmoji): The emoji being used. - message (discord.Message): The current message. - reaction_user (discord.Member): The user who triggered the reaction event. - """ - ticket = db.get_ticket_by_message(message.id) - if ticket['ticket_id'] == -1: - return - - await message.remove_reaction(emoji, reaction_user) - - embed_actions = embed_reactions.embed_reactions(bot, guild, modmail_channel, reaction_user, ticket) - - if str(emoji) == '🗣️': - await embed_actions.message_reply() - elif str(emoji) == '❎': - await embed_actions.message_close() - elif str(emoji) == '⏲️': - ticket_user = await bot.fetch_user(ticket['user']) - await embed_actions.message_timeout(ticket_user) - -@bot.event -async def on_ready(): - global guild, modmail_channel - - guild = bot.get_guild(config['guild']) - - if guild is None: - print('Failed to find Guild from provided ID.') - await bot.close() - return - - modmail_channel = guild.get_channel(config['channel']) - - if modmail_channel is None: - print('Failed to find Modmail Channel from provided ID.') - await bot.close() - return - - await bot.change_presence(activity=discord.Game(name=config['status']), status=discord.Status.online) - - bot.add_cog(Listeners(bot, guild, modmail_channel)) - bot.add_cog(CommandActions(bot, modmail_channel)) - -def ready(): - if db.init(): - print('Database sucessfully initialized!') - else: - print('Error while initializing database!') - return False - - if "guild" not in config: - print('No Guild ID provided.') - return False - - if "channel" not in config: - print('No Channel ID provided.') - return False - - if "prefix" not in config: - print('Failed to find prefix in config.') - return False - - return True - -success = ready() - -if success: - bot.run(config['token']) + for cog in ('commands', 'listeners'): + try: + await bot.load_extension(f'cogs.{cog}') + logger.debug(f'Imported cog "{cog}".') + except commands.errors.NoEntryPointError as e: + logger.warning(e) + except commands.errors.ExtensionNotFound as e: + logger.warning(e) + except commands.errors.ExtensionFailed as e: + logger.error(e) + logger.info("Loaded all cogs.") + + self.add_view(MessageButtonsView(bot)) + logger.info("Added all views.") + + async def on_ready(self): + await bot.change_presence(activity=discord.Game(name=config.status), + status=discord.Status.online) + + logger.info(f"Bot \"{bot.user.name}\" is now connected.") + + async def on_command_error(self, ctx: commands.Context, exception) -> None: + await super().on_command_error(ctx, exception) + await ctx.send(exception) + -else: - print('Error during starting process') +bot = Modmail() +bot.run(config.token) diff --git a/utils/actions.py b/utils/actions.py new file mode 100644 index 0000000..401b066 --- /dev/null +++ b/utils/actions.py @@ -0,0 +1,226 @@ +import asyncio +import datetime +from typing import Optional +import discord +from discord.ext import commands + +import db +from utils import ticket_embed, uformatter + +import logging + +logger = logging.getLogger(__name__) + + +async def waiter( + bot: commands.Bot, + interaction: discord.Interaction) -> Optional[discord.Message]: + + def check(message: discord.Message) -> bool: + return message.author == interaction.user and message.channel == interaction.channel + + try: + message = await bot.wait_for('message', check=check) + except asyncio.TimeoutError: + return None + + return message + + +async def message_open(bot: commands.Bot, interaction: discord.Interaction, + member: discord.Member): + ticket = await db.get_ticket_by_user(member.id) + + if ticket: + await interaction.response.send_message( + f'There is already a ticket open for {member.name}.') + return + + ticket_id = await db.open_ticket(member.id) + ticket = await db.get_ticket(ticket_id) + + message_embed, buttons_view = await ticket_embed.channel_embed( + bot, interaction.guild, ticket) + await interaction.response.send_message(embed=message_embed, + view=buttons_view) + + ticket_message = await interaction.original_response() + logger.debug(f"Ticket message: {ticket_message}") + await db.update_ticket_message(ticket.ticket_id, ticket_message.id) + + +async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, + member: discord.Member): + ticket = await db.get_ticket_by_user(member.id) + + if not ticket: + await interaction.response.send_message( + f'There is no ticket open for {member.name}.') + return + + message_embed, buttons_view = await ticket_embed.channel_embed( + bot, interaction.guild, ticket) + await interaction.response.send_message(embed=message_embed, + view=buttons_view) + + message = await interaction.original_response() + await db.update_ticket_message(ticket.ticket_id, message.id) + + if ticket.message_id is not None: + old_ticket_message = await interaction.channel.fetch_message( + ticket.message_id) + await old_ticket_message.delete() + + +async def message_close(interaction: discord.Interaction, ticket: db.Ticket, + user: discord.Member): + """Sends close confirmation embed, and if confirmed, will close the ticket.""" + + close_embed, confirmation_view = ticket_embed.close_confirmation(user) + + await interaction.response.send_message(embed=close_embed, + view=confirmation_view) + confirmation_view.message = await interaction.original_response() + + await confirmation_view.wait() + + if confirmation_view.value is None: + return + elif confirmation_view.value: + await db.close_ticket(ticket.ticket_id) + ticket_message = await interaction.channel.fetch_message( + ticket.message_id) + await ticket_message.delete() + + await interaction.channel.send( + embed=ticket_embed.closed_ticket(interaction.user, user)) + logger.info( + f"Ticket for user {user.id} closed by {interaction.user.id}") + + +async def message_reply(bot: commands.Bot, interaction: discord.Interaction, + ticket: db.Ticket): + """Sends reply embed, and if confirmed, will add the staff message to the ticket embeds.""" + + ticket_user = interaction.guild.get_member(ticket.user) + + if not ticket_user: + await interaction.channel.send( + "Cannot reply to ticket as user is not in the server.") + return + + task = bot.loop.create_task(waiter(bot, interaction)) + + reply_embed, cancel_view = ticket_embed.reply_cancel(ticket_user, task) + await interaction.response.send_message(embed=reply_embed, + view=cancel_view) + cancel_view.message = await interaction.original_response() + + await task + await cancel_view.view_cleanup() + + try: + message = task.result() + + if not message: + return + + response = uformatter.format_message(message) + + if not response.strip(): + return + + # ! Fix for longer messages + if len(response) > 1000: + await interaction.channel.send( + 'Your message is too long. Please shorten your message or send in multiple parts.' + ) + return + + try: + await ticket_user.send( + embed=ticket_embed.user_embed(interaction.guild, response)) + await db.add_ticket_response(ticket.ticket_id, interaction.user.id, + response, True) + ticket_message = await interaction.channel.fetch_message( + ticket.message_id) + + channel_embed, _ = await ticket_embed.channel_embed( + bot, interaction.guild, ticket) + await ticket_message.edit(embed=channel_embed) + except discord.errors.Forbidden: + await interaction.channel.send( + "Could not send ModMail message to specified user due to privacy settings." + ) + + except Exception as e: + raise RuntimeError(e) + except asyncio.CancelledError: + return + + +async def message_timeout(interaction: discord.Interaction, + member: discord.Member): + """Sends timeout confirmation embed, and if confirmed, will timeout the specified ticket user. + + Args: + ticket_user (discord.User): The ticket user. + """ + + timeout_embed, confirmation_view = ticket_embed.timeout_confirmation( + member) + + await interaction.response.send_message(embed=timeout_embed, + view=confirmation_view) + confirmation_view.message = await interaction.original_response() + + await confirmation_view.wait() + + if confirmation_view.value is None: + return + elif confirmation_view.value: + timeout = datetime.datetime.now() + datetime.timedelta(days=1) + timestamp = int(timeout.timestamp()) + await db.set_timeout(member.id, timestamp) + logger.info(f"User {member.id} timed out by {interaction.user.id}") + + await interaction.channel.send( + f'{member.name} has been successfully timed out for 24 hours. They will be able to message ModMail again after .' + ) + + # TODO: Handle when DMs are disabled + # "Could not send timeout message to specified user due to privacy settings. Timeout has not been set." + await member.send(embed=ticket_embed.user_timeout(timestamp)) + + +async def message_untimeout(interaction: discord.Interaction, + member: discord.Member): + timeout = await db.get_timeout(member.id) + current_time = int(datetime.datetime.now().timestamp()) + + if not timeout or (current_time > timeout.timestamp): + await interaction.response.send_message( + f'{member.name} is not currently timed out.') + return + + untimeout_embed, confirmation_view = ticket_embed.untimeout_confirmation( + member, timeout.timestamp) + + await interaction.response.send_message(embed=untimeout_embed, + view=confirmation_view) + confirmation_view.message = await interaction.original_response() + + await confirmation_view.wait() + + if confirmation_view.value is None: + return + elif confirmation_view.value: + timestamp = int(datetime.datetime.now().timestamp()) + await db.set_timeout(member.id, timestamp) + logger.info(f"Timeout removed for {member.id}.") + + await interaction.channel.send( + f'Timeout has been removed for {member.name}.') + + # TODO: Handle when DMs are disabled + await member.send(embed=ticket_embed.user_untimeout()) diff --git a/utils/embed_reactions.py b/utils/embed_reactions.py deleted file mode 100644 index 5da2ad1..0000000 --- a/utils/embed_reactions.py +++ /dev/null @@ -1,143 +0,0 @@ -import asyncio, datetime -import discord -import db, utils.uformatter as uformatter, utils.ticket_embed as ticket_embed - -class embed_reactions(): - """Class to contain channel embed reaction methods.""" - - def __init__(self, bot, guild, modmail_channel, reaction_user, ticket=None): - """Constructs necessary attributes for all embed reaction methods. - - Args: - bot (discord.Bot): The bot object. - guild (discord.Guild): The current guild. - modmail_channel (discord.Channel): The ModMail guild channel. - reaction_user (discord.User): The user who triggered the reaction. - ticket (DB Object, optional): Object containing values for a specific ticket. Defaults to None. - """ - - self.bot = bot - self.guild = guild - self.modmail_channel = modmail_channel - self.reaction_user = reaction_user - self.ticket = ticket - - async def message_reply(self): - """Sends reply embed, and if confirmed, will add the staff message to the ticket embeds.""" - - ticket_user = self.guild.get_member(self.ticket['user']) - - if not ticket_user: - await self.modmail_channel.send("Cannot reply to ticket as user is not in the server.") - return - - cancel = await self.modmail_channel.send(embed=ticket_embed.reply_cancel(ticket_user)) - await cancel.add_reaction('❎') - - def reply_cancel(reaction, user): - return user == self.reaction_user and cancel == reaction.message and str(reaction.emoji) == '❎' - def reply_message(message): - return message.author == self.reaction_user and message.channel == self.modmail_channel - - try: - tasks = [ - asyncio.create_task(self.bot.wait_for('reaction_add', timeout=60.0, check=reply_cancel), name='cancel'), - asyncio.create_task(self.bot.wait_for('message', timeout=60.0,check=reply_message), name='respond') - ] - - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - - event: asyncio.Task = list(done)[0] - - for task in pending: - try: - task.cancel() - except asyncio.CancelledError: - pass - - - if event.get_name() == 'respond': - message = event.result() - response = uformatter.format_message(message) - - if not response.strip(): - return - - # ! Fix for longer messages - if len(response) > 1000: - await message.channel.send('Your message is too long. Please shorten your message or send in multiple parts.') - return - # Removed - can add as config option later - # await message.delete() - await ticket_user.send(embed=ticket_embed.user_embed(self.guild, response)) - db.add_ticket_response(self.ticket['ticket_id'], self.reaction_user.id, response, True) - ticket_message = await self.modmail_channel.fetch_message(self.ticket['message_id']) - await ticket_message.edit(embed=ticket_embed.channel_embed(self.guild, self.ticket['ticket_id'])) - - except asyncio.TimeoutError: - pass - except discord.errors.Forbidden: - await self.modmail_channel.send("Could not send ModMail message to specified user due to privacy settings.") - except Exception as e: - raise RuntimeError(e) - - await cancel.delete() - - async def message_close(self): - """Sends close confirmation embed, and if confirmed, will close the ticket.""" - - ticket_user = await self.bot.fetch_user(self.ticket['user']) - - if not ticket_user: - ticket_user = "user" - - confirmation = await self.modmail_channel.send(embed=ticket_embed.close_confirmation(ticket_user)) - await confirmation.add_reaction('✅') - await confirmation.add_reaction('❎') - - def close_check(reaction, user): - return user == self.reaction_user and confirmation == reaction.message and (str(reaction.emoji) == '✅' or str(reaction.emoji) == '❎') - - try: - reaction, user = await self.bot.wait_for('reaction_add', timeout=60.0, check=close_check) - if str(reaction.emoji) == '✅': - db.close_ticket(self.ticket['ticket_id']) - ticket_message = await self.modmail_channel.fetch_message(self.ticket['message_id']) - await ticket_message.delete() - await self.modmail_channel.send(embed=ticket_embed.closed_ticket(self.reaction_user, ticket_user)) - except asyncio.TimeoutError: - pass - - await confirmation.delete() - - async def message_timeout(self, ticket_user): - """Sends timeout confirmation embed, and if confirmed, will timeout the specified ticket user. - - Args: - ticket_user (discord.User): The ticket user. - """ - - confirmation = await self.modmail_channel.send(embed=ticket_embed.timeout_confirmation(ticket_user)) - await confirmation.add_reaction('✅') - await confirmation.add_reaction('❎') - - def timeout_check(reaction, user): - return user == self.reaction_user and confirmation == reaction.message and (str(reaction.emoji) == '✅' or str(reaction.emoji) == '❎') - - try: - reaction, user = await self.bot.wait_for('reaction_add', timeout=60.0, check=timeout_check) - if str(reaction.emoji) == '✅': - # Change below value to custom - timeout = datetime.datetime.now() + datetime.timedelta(days=1) - timestamp = int(timeout.timestamp()) - db.set_timeout(ticket_user.id, timestamp) - await self.modmail_channel.send('{0} has been successfully timed out for 24 hours. They will be able to message ModMail again after .'.format(ticket_user, timestamp)) - await ticket_user.send(embed=ticket_embed.user_timeout(timestamp)) - except asyncio.TimeoutError: - pass - except discord.errors.Forbidden: - pass - except Exception as e: - raise RuntimeError(e) - - await confirmation.delete() diff --git a/utils/ticket_embed.py b/utils/ticket_embed.py index 78d3541..6c2ec5d 100644 --- a/utils/ticket_embed.py +++ b/utils/ticket_embed.py @@ -1,7 +1,102 @@ +import asyncio import discord +from discord.ext import commands +from discord.utils import format_dt import db +from typing import Optional, Union -def user_embed(guild, message): +from utils import actions, config + +import logging + +logger = logging.getLogger(__name__) + + +class ConfirmationView(discord.ui.View): + + def __init__(self, + message: Optional[discord.Message] = None, + timeout: int = 60): + super().__init__(timeout=timeout) + self.message = message + self.value = None + + async def on_timeout(self) -> None: + await self.message.delete() + await super().on_timeout() + + @discord.ui.button(label='Yes', style=discord.ButtonStyle.green) + async def yes(self, interaction: discord.Interaction, + button: discord.ui.Button): + self.value = True + await self.message.delete() + self.stop() + + @discord.ui.button(label='No', style=discord.ButtonStyle.red) + async def no(self, interaction: discord.Interaction, + button: discord.ui.Button): + self.value = False + await self.message.delete() + self.stop() + + +class CancelView(discord.ui.View): + + def __init__(self, + task: asyncio.Task, + message: Optional[discord.Message] = None, + timeout: int = 60) -> None: + super().__init__(timeout=timeout) + self.message = message + self.task = task + + async def view_cleanup(self) -> None: + self.stop() + if self.message: + await self.message.delete() + + self.task.cancel() + + async def on_timeout(self) -> None: + await self.view_cleanup() + await super().on_timeout() + + @discord.ui.button(label='Cancel', style=discord.ButtonStyle.red) + async def cancel(self, interaction: discord.Interaction, + button: discord.ui.Button): + await self.view_cleanup() + + +class MessageButtonsView(discord.ui.View): + + def __init__(self, bot: commands.Bot): + super().__init__(timeout=None) + self.bot = bot + + @discord.ui.button(emoji="💬", custom_id=f"{config.id_prefix}:reply") + async def mail_reply(self, interaction: discord.Interaction, + button: discord.ui.Button): + ticket = await db.get_ticket_by_message(interaction.message.id) + await actions.message_reply(self.bot, interaction, ticket) + + @discord.ui.button(emoji="❎", custom_id=f"{config.id_prefix}:close") + async def mail_close(self, interaction: discord.Interaction, + button: discord.ui.Button): + ticket = await db.get_ticket_by_message(interaction.message.id) + member = interaction.guild.get_member( + ticket.user) or await interaction.guild.fetch_member(ticket.user) + await actions.message_close(interaction, ticket, member) + + @discord.ui.button(emoji="⏲️", custom_id=f"{config.id_prefix}:timeout") + async def mail_timeout(self, interaction: discord.Interaction, + button: discord.ui.Button): + ticket = await db.get_ticket_by_message(interaction.message.id) + member = interaction.guild.get_member( + ticket.user) or await interaction.guild.fetch_member(ticket.user) + await actions.message_timeout(interaction, member) + + +def user_embed(guild: discord.Guild, message: str) -> discord.Embed: """Returns formatted embed for user DMs. Args: @@ -12,45 +107,52 @@ def user_embed(guild, message): discord.Embed: User DM embed containing the message content. """ - message_embed = discord.Embed( - title="New Mail from {0}".format(guild.name), - description=message - ) + message_embed = discord.Embed(title=f"New Mail from {guild.name}", + description=message) return message_embed -def channel_embed(guild, ticket_id): + +async def channel_embed( + bot: commands.Bot, guild: discord.Guild, + ticket: db.Ticket) -> tuple[discord.Embed, discord.ui.View]: """Returns formatted embed for channel. Args: guild (discord.Guild): The guild. - ticket_id (int): The ticket id. + ticket (Ticket): The ticket. Returns: discord.Embed: Channel embed containing message and user content. """ - ticket = db.get_ticket(ticket_id) - - ticket_member = guild.get_member(ticket['user']) + # WARNING: Handle when user is not in guild or in DM listener + ticket_member = guild.get_member(ticket.user) or await guild.fetch_member( + ticket.user) message_embed = discord.Embed( - title="ModMail Conversation for {0}".format(ticket_member), - description="User {0} has **{1}** roles\n Joined Discord: **{2}**\n Joined Server: **{3}**" - .format(ticket_member.mention, len(ticket_member.roles), ticket_member.created_at.strftime("%B %d %Y"), ticket_member.joined_at.strftime("%B %d %Y")) + title=f"ModMail Conversation for {ticket_member}", + description= + f"User {ticket_member.mention} has **{len(ticket_member.roles) - 1}** roles\n Joined Discord: **{format_dt(ticket_member.created_at, 'D')}**\n Joined Server: **{format_dt(ticket_member.joined_at, 'D')}**" ) - responses = db.get_ticket_responses(ticket_id) + responses = await db.get_ticket_responses(ticket.ticket_id) for response in responses: author = 'user' - if response['as_server']: - author = '{0} as server'.format(guild.get_member(response['user'])) - message_embed.add_field(name=", {1} wrote".format(response['timestamp'], author), value=response['response'], inline=False) + if response.as_server: + author = f'{guild.get_member(response.user)} as server' + message_embed.add_field( + name=f", {author} wrote", + value=response.response, + inline=False) + + message_buttons_view = MessageButtonsView(bot) + return message_embed, message_buttons_view - return message_embed -def close_confirmation(member): +def close_confirmation( + member: discord.Member) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket close confirmation. Args: @@ -60,13 +162,18 @@ def close_confirmation(member): discord.Embed: Channel embed for close confirmation. """ + confirmation_view = ConfirmationView() + message_embed = discord.Embed( - description="Do you want to close the ModMail conversation for **{0}**?".format(member) + description= + f"Do you want to close the ModMail conversation for **{member.name}**?" ) - return message_embed + return message_embed, confirmation_view + -def timeout_confirmation(member): +def timeout_confirmation( + member: discord.Member) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket timeout confirmation. Args: @@ -76,13 +183,17 @@ def timeout_confirmation(member): discord.Embed: Channel embed for timeout confirmation. """ + confirmation_view = ConfirmationView() + message_embed = discord.Embed( - description="Do you want to timeout **{0}** for 24 hours?".format(member) - ) + description=f"Do you want to timeout **{member.name}** for 24 hours?") + + return message_embed, confirmation_view - return message_embed -def untimeout_confirmation(member, timeout): +def untimeout_confirmation( + member: discord.Member, + timeout: int) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket untimeout confirmation. Args: @@ -92,14 +203,18 @@ def untimeout_confirmation(member, timeout): Returns: discord.Embed: Channel embed for untimeout confirmation. """ + confirmation_view = ConfirmationView() message_embed = discord.Embed( - description="Do you want to untimeout **{0}** (they are currently timed out until )?".format(member, timeout) + description= + f"Do you want to untimeout **{member.name}** (they are currently timed out until )?" ) - return message_embed + return message_embed, confirmation_view + -def reply_cancel(member): +def reply_cancel(member: discord.Member, + task: asyncio.Task) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for replying to ticket with cancel reaction. Args: @@ -109,17 +224,19 @@ def reply_cancel(member): discord.Embed: Channel embed for ticket reply. """ + cancel_view = CancelView(task) message_embed = discord.Embed( - description="Replying to ModMail conversation for **{0}**".format(member) - ) + description=f"Replying to ModMail conversation for **{member}**") + + return message_embed, cancel_view - return message_embed -def closed_ticket(staff, member): +def closed_ticket(staff: Union[discord.User, discord.Member], + member: discord.Member) -> discord.Embed: """Returns embed for closed ticket. Args: - staff (discord.Member): The staff member who closed the ticket. + staff (Union[discord.User, discord.Member]): The staff member who closed the ticket. member (discord.Member): The ticket user. Returns: @@ -127,12 +244,13 @@ def closed_ticket(staff, member): """ message_embed = discord.Embed( - description="**{0}** closed the ModMail conversation for **{1}**".format(staff, member) - ) + description= + f"**{staff}** closed the ModMail conversation for **{member}**") return message_embed -def user_timeout(timeout): + +def user_timeout(timeout: int) -> discord.Embed: """Returns embed for user timeout in DMs. Args: @@ -143,20 +261,22 @@ def user_timeout(timeout): """ message_embed = discord.Embed( - description="You have been timed out. You will be able to message ModMail again after .".format(timeout) + description= + f"You have been timed out. You will be able to message ModMail again after ." ) return message_embed -def user_untimeout(): + +def user_untimeout() -> discord.Embed: """Returns embed for user untimeout in DMs. Returns: discord.Embed: Channel embed for user untimeout. """ - + message_embed = discord.Embed( - description="Your timeout has been removed. You can message ModMail again.".format() - ) + description= + "Your timeout has been removed. You can message ModMail again.") - return message_embed \ No newline at end of file + return message_embed diff --git a/utils/uformatter.py b/utils/uformatter.py index 0910d68..57070f4 100644 --- a/utils/uformatter.py +++ b/utils/uformatter.py @@ -1,11 +1,15 @@ -def format_message(message): +import discord + + +def format_message(message: discord.Message) -> str: attachments = message.attachments - formatted_message = '{0}'.format(message.content) + formatted_message = f'{message.content}' for attachment in attachments: - type = 'Unknown' + attachment_type = 'Unknown' if attachment.content_type.startswith('image'): - type = 'Image' + attachment_type = 'Image' elif attachment.content_type.startswith('video'): - type = 'Video' - formatted_message += '\n[{0} Attachment]({1})'.format(type, attachment.url) - return formatted_message \ No newline at end of file + attachment_type = 'Video' + formatted_message += f'\n[{attachment_type} Attachment]({attachment.url})' + + return formatted_message From c7527e4e8bda04170593be25b9fecebb96a88e90 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:21:46 +0930 Subject: [PATCH 05/22] Updated error handling and cog checks. --- cogs/commands.py | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index f7249de..95692c2 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -29,19 +29,45 @@ def __init__(self, bot: commands.Bot, async def cog_check(self, ctx: commands.Context): if ctx.channel != self.modmail_channel: + await ctx.send("Command must be used in the modmail channel.") return False if ctx.author == self.bot: + await ctx.send("Bots cannot use commands.") return False return True + async def interaction_check(self, interaction: discord.Interaction): + if interaction.channel != self.modmail_channel: + await interaction.response.send_message( + "Command must be used in the modmail channel.") + return False + + if "resolved" in interaction.data and "users" in interaction.data[ + "resolved"] and len(interaction.data["resolved"]["users"]) > 0: + user_id = next(iter(interaction.data["resolved"]["users"])) + user = interaction.data["resolved"]["users"][user_id] + if "bot" in user and user["bot"]: + await interaction.response.send_message( + "Invalid user specified.") + return False + + return True + @commands.command(name="sync") @commands.has_permissions(administrator=True) @commands.guild_only() async def sync(self, ctx: commands.Context, spec: Optional[Literal["~"]] = None): + """ + Syncs commands to the current guild or globally. + + Args: + ctx (commands.Context): The command context. + spec (Optional[Literal["~"]]): If "~", syncs globally. Defaults to None. + """ if spec == "~": synced = await ctx.bot.tree.sync() else: @@ -59,11 +85,6 @@ async def open_ticket(self, interaction: discord.Interaction, member: discord.Member): """Opens ticket for specified user if no tickets are currently open.""" - # TODO: Check if you can share this across commands - if member.bot: - await interaction.response.send_message( - f'Cannot open ticket for {member.name} (bot).') - return await actions.message_open(self.bot, interaction, member) @app_commands.command(name="refresh") @@ -72,10 +93,6 @@ async def refresh_ticket(self, interaction: discord.Interaction, member: discord.Member): """Resends embed for specified user if there is a ticket that is already open.""" - if member.bot: - await interaction.response.send_message('Invalid member specified.' - ) - return await actions.message_refresh(self.bot, interaction, member) @app_commands.command(name="close") @@ -84,16 +101,11 @@ async def close_ticket(self, interaction: discord.Interaction, member: discord.Member): """Closes ticket for specified user given that a ticket is already open.""" - if member.bot: - await interaction.response.send_message('Invalid member specified.' - ) - return - ticket = await db.get_ticket_by_user(member.id) if not ticket: await interaction.response.send_message( - f'There is no ticket open for {member.name}.') + f'There is no ticket open for {member.name}.', ephemeral=True) return await actions.message_close(interaction, ticket, member) @@ -103,10 +115,6 @@ async def close_ticket(self, interaction: discord.Interaction, async def timeout_ticket(self, interaction: discord.Interaction, member: discord.Member): """Times out specified user.""" - if member.bot: - await interaction.response.send_message('Invalid member specified.' - ) - return await actions.message_timeout(interaction, member) @app_commands.command(name="untimeout") @@ -115,16 +123,12 @@ async def untimeout_ticket(self, interaction: discord.Interaction, member: discord.Member): """Removes timeout for specified user given that user is currently timed out.""" - if member.bot: - await interaction.response.send_message('Invalid member specified.' - ) - return await actions.message_untimeout(interaction, member) - # FIX: Fix errors - async def cog_command_error(self, ctx, error): + async def cog_command_error(self, ctx: commands.Context, + error: commands.CommandError): if type(error) == commands.errors.CheckFailure: - print("Command executed in wrong channel.") + logger.error("Checks failed for interaction.") elif type(error) == commands.errors.MissingRequiredArgument: await ctx.send( "A valid user (and one who is still on the server) was not specified." @@ -133,10 +137,23 @@ async def cog_command_error(self, ctx, error): await ctx.send( str(error) + "\nIf you do not understand, contact a bot dev.") + async def cog_app_command_error(self, interaction: discord.Interaction, + error: app_commands.AppCommandError): + if type(error) == app_commands.errors.CheckFailure: + logger.error("Checks failed for interaction.") + else: + await interaction.response.send_message( + str(error) + "\nIf you do not understand, contact a bot dev.") + logger.error(error) + async def setup(bot: commands.Bot): try: modmail_channel = await bot.fetch_channel(config.channel) + + if type(modmail_channel) != discord.TextChannel: + raise TypeError( + "The channel specified in config was not a text channel.") except Exception as e: logger.error(e) logger.fatal( From 1a8b2e106126ebd5ac05e64cf72a1cd085d4794d Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:22:20 +0930 Subject: [PATCH 06/22] Updated error handling and pagination view for embeds. --- cogs/listeners.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/cogs/listeners.py b/cogs/listeners.py index 05832b6..5e56363 100644 --- a/cogs/listeners.py +++ b/cogs/listeners.py @@ -35,8 +35,19 @@ async def on_message(self, message: discord.Message): """ if message.guild is None and not message.author.bot: + # Handle if user is not in guild + if not (self.modmail_channel.guild.get_member(message.author.id) + or await self.modmail_channel.guild.fetch_member( + message.author.id)): + try: + await message.author.send( + 'Unable to send message. Please ensure you have joined the server.' + ) + except discord.errors.Forbidden: + pass + return + await self.handle_dm(message) - return async def handle_dm(self, message: discord.Message): """Handle DM messages. @@ -74,9 +85,7 @@ async def handle_dm(self, message: discord.Message): logger.info(f"Opened new ticket for: {user.id}") try: - if ticket and ticket.message_id is not None and ticket.message_id != -1: - # WARNING: Fix handling other channels - # FIX: what if someone deletes the embed + if ticket and ticket.message_id is not None: old_ticket_message = await self.modmail_channel.fetch_message( ticket.message_id) await old_ticket_message.delete() @@ -90,8 +99,12 @@ async def handle_dm(self, message: discord.Message): await db.add_ticket_response(ticket.ticket_id, user.id, response, False) - message_embed, buttons_view = await ticket_embed.channel_embed( - self.bot, self.modmail_channel.guild, ticket) + embeds = await ticket_embed.channel_embed(self.modmail_channel.guild, + ticket) + + message_embed, buttons_view = await ticket_embed.MessageButtonsView( + self.bot, embeds).return_paginated_embed() + ticket_message = await self.modmail_channel.send(embed=message_embed, view=buttons_view) await message.add_reaction('📨') @@ -100,8 +113,17 @@ async def handle_dm(self, message: discord.Message): async def setup(bot: commands.Bot): + """Setup function for the listeners cog. + + Args: + bot (commands.Bot): The bot. + """ try: modmail_channel = await bot.fetch_channel(config.channel) + + if type(modmail_channel) != discord.TextChannel: + raise TypeError( + "The channel specified in config was not a text channel.") except Exception as e: logger.error(e) logger.fatal( From 6c1405f27d87d641edc1d33e713ee2828415a6c8 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:24:37 +0930 Subject: [PATCH 07/22] Added pagination support to message embeds. --- utils/pagination.py | 79 +++++++++++++++++++++++ utils/ticket_embed.py | 145 ++++++++++++++++++++++++++++++++---------- 2 files changed, 191 insertions(+), 33 deletions(-) create mode 100644 utils/pagination.py diff --git a/utils/pagination.py b/utils/pagination.py new file mode 100644 index 0000000..ac6de0f --- /dev/null +++ b/utils/pagination.py @@ -0,0 +1,79 @@ +from typing import Union, Optional +from collections.abc import Collection + +import discord + +NAME_SIZE_LIMIT = 256 +VALUE_SIZE_LIMIT = 1024 + + +def paginated_embed_menus( + names: Collection[str], + values: Collection[str], + pagesize: int = 10, + *, + inline: Union[Collection[bool], bool] = False, + embed_dict: Optional[dict] = None, +) -> Collection[discord.Embed]: + """ + Generates embeds for a paginated embed view. + + Args: + names (Collection[str]): Names of fields to be added/paginated. + values (Collection[str]): Values of fields to be added/paginated. + pagesize (int, optional): Maximum number of items per page. Defaults to 10. + inline (Union[Collection[bool], bool], optional): Whether embed fields should be inline or not. Defaults to False. + embed_dict (Optional[dict], optional): Partial embed dictionary (for setting a title, description, etc.). Footer and fields must not be set. Defaults to None. + + Returns: + Collection[discord.Embed]: Collection of embeds for paginated embed view. + """ + N = len(names) + if N != len(values): + raise ValueError( + 'names and values for paginated embed menus must be of equal length.' + ) + if isinstance(inline, bool): + inline = [inline] * N + elif N != len(inline): + raise ValueError( + '"inline" must be boolean or a collection of booleans of equal length to names/values for paginated embed menus.' + ) + + if embed_dict: + if 'title' in embed_dict and len(embed_dict['title']) > 256: + raise ValueError('title cannot be over 256 characters') + if 'description' in embed_dict and len( + embed_dict['description']) > 4096: + raise ValueError('desription cannot be over 4096 characters') + if 'footer' in embed_dict: + raise ValueError('embed_dict "footer" key must not be set.') + if 'fields' in embed_dict: + raise ValueError('embed_dict "fields" key must not be set.') + else: + embed_dict = { # default + 'description': 'Here is a list of entries.' + } + + if N == 0: + return [discord.Embed.from_dict(embed_dict)] + + embeds: Collection[discord.Embed] = [] + current: discord.Embed = discord.Embed.from_dict(embed_dict) + pages = 1 + items = 0 + for name, value, inline_field in zip(names, values, inline): + if items == pagesize or len(current) + len(name) + len( + value) > 5090: # leave 10 chars for footers + embeds.append(current) + current = discord.Embed.from_dict(embed_dict) + pages += 1 + items = 0 + + current.add_field(name=name, value=value, inline=inline_field) + items += 1 + embeds.append(current) + for page, embed in enumerate(embeds): + embed.set_footer(text=f"Page {page+1}/{pages}") + + return embeds diff --git a/utils/ticket_embed.py b/utils/ticket_embed.py index 6c2ec5d..fef5da2 100644 --- a/utils/ticket_embed.py +++ b/utils/ticket_embed.py @@ -3,16 +3,19 @@ from discord.ext import commands from discord.utils import format_dt import db -from typing import Optional, Union +from typing import Collection, Optional, Union from utils import actions, config import logging +from utils.pagination import paginated_embed_menus + logger = logging.getLogger(__name__) class ConfirmationView(discord.ui.View): + """Confirmation view for yes/no operations.""" def __init__(self, message: Optional[discord.Message] = None, @@ -41,6 +44,7 @@ async def no(self, interaction: discord.Interaction, class CancelView(discord.ui.View): + """Cancel view for cancelling operations if requested by the user.""" def __init__(self, task: asyncio.Task, @@ -68,33 +72,105 @@ async def cancel(self, interaction: discord.Interaction, class MessageButtonsView(discord.ui.View): + """Message buttons view for ticket messages.""" - def __init__(self, bot: commands.Bot): + def __init__(self, bot: commands.Bot, embeds: Collection[discord.Embed]): super().__init__(timeout=None) self.bot = bot + self.embeds = embeds + self.current_page = len(self.embeds) - 1 @discord.ui.button(emoji="💬", custom_id=f"{config.id_prefix}:reply") - async def mail_reply(self, interaction: discord.Interaction, - button: discord.ui.Button): + async def mail_reply(self, interaction: discord.Interaction, _): + """ + Replies to the ticket. + """ ticket = await db.get_ticket_by_message(interaction.message.id) await actions.message_reply(self.bot, interaction, ticket) @discord.ui.button(emoji="❎", custom_id=f"{config.id_prefix}:close") - async def mail_close(self, interaction: discord.Interaction, - button: discord.ui.Button): + async def mail_close(self, interaction: discord.Interaction, _): + """ + Closes the ticket. + """ ticket = await db.get_ticket_by_message(interaction.message.id) member = interaction.guild.get_member( ticket.user) or await interaction.guild.fetch_member(ticket.user) await actions.message_close(interaction, ticket, member) @discord.ui.button(emoji="⏲️", custom_id=f"{config.id_prefix}:timeout") - async def mail_timeout(self, interaction: discord.Interaction, - button: discord.ui.Button): + async def mail_timeout(self, interaction: discord.Interaction, _): + """ + Times out the user of the ticket. + """ ticket = await db.get_ticket_by_message(interaction.message.id) member = interaction.guild.get_member( ticket.user) or await interaction.guild.fetch_member(ticket.user) await actions.message_timeout(interaction, member) + @discord.ui.button(emoji="⬅️", + style=discord.ButtonStyle.blurple, + custom_id=f"{config.id_prefix}:before_page") + async def before_page(self, interaction: discord.Interaction, _): + """ + Goes to the previous page. + """ + if len(self.embeds) == 0: + await interaction.response.send_message( + "Please refresh this ticket to be able to use pagination.", + ephemeral=True) + return + + if self.current_page > 0: + self.current_page -= 1 + self.update_pagination_buttons() + await self.update_view(interaction) + + @discord.ui.button(emoji="➡️", + style=discord.ButtonStyle.blurple, + custom_id=f"{config.id_prefix}:next_page") + async def next_page(self, interaction: discord.Interaction, _): + """ + Goes to the next page. + """ + if len(self.embeds) == 0: + await interaction.response.send_message( + "Please refresh this ticket to be able to use pagination.", + ephemeral=True) + return + + if self.current_page < len(self.embeds) - 1: + self.current_page += 1 + self.update_pagination_buttons() + await self.update_view(interaction) + + def update_pagination_buttons(self): + """ + Updates the buttons based on the current page. + """ + for i in self.children: + i.disabled = False + if self.current_page == 0: + self.children[3].disabled = True + if self.current_page == len(self.embeds) - 1: + self.children[4].disabled = True + + async def update_view(self, interaction: discord.Interaction): + """ + Updates the embed and view. + """ + await interaction.response.edit_message( + embed=self.embeds[self.current_page], view=self) + + async def return_paginated_embed( + self) -> tuple[discord.Embed, discord.ui.View | None]: + """ + Returns the first embed and containing view. + """ + self.update_pagination_buttons() # Disable buttons only one embed + + return self.embeds[self.current_page], self + def user_embed(guild: discord.Guild, message: str) -> discord.Embed: """Returns formatted embed for user DMs. @@ -113,42 +189,43 @@ def user_embed(guild: discord.Guild, message: str) -> discord.Embed: return message_embed -async def channel_embed( - bot: commands.Bot, guild: discord.Guild, - ticket: db.Ticket) -> tuple[discord.Embed, discord.ui.View]: - """Returns formatted embed for channel. +async def channel_embed(guild: discord.Guild, + ticket: db.Ticket) -> Collection[discord.Embed]: + """Returns formatted embed for modmail channel. Args: guild (discord.Guild): The guild. - ticket (Ticket): The ticket. + ticket (db.Ticket): The ticket. Returns: - discord.Embed: Channel embed containing message and user content. + Collection[discord.Embed]: Collection of embeds for the ticket. """ - # WARNING: Handle when user is not in guild or in DM listener ticket_member = guild.get_member(ticket.user) or await guild.fetch_member( ticket.user) - message_embed = discord.Embed( - title=f"ModMail Conversation for {ticket_member}", - description= - f"User {ticket_member.mention} has **{len(ticket_member.roles) - 1}** roles\n Joined Discord: **{format_dt(ticket_member.created_at, 'D')}**\n Joined Server: **{format_dt(ticket_member.joined_at, 'D')}**" - ) - responses = await db.get_ticket_responses(ticket.ticket_id) + names = [] + values = [] + for response in responses: author = 'user' if response.as_server: author = f'{guild.get_member(response.user)} as server' - message_embed.add_field( - name=f", {author} wrote", - value=response.response, - inline=False) + names.append(f", {author} wrote") + values.append(response.response) + + embed_dict = { + "title": + f"ModMail Conversation for {ticket_member.name}", + "description": + f"User {ticket_member.mention} has **{len(ticket_member.roles) - 1}** roles\n Joined Discord: **{format_dt(ticket_member.created_at, 'D')}**\n Joined Server: **{format_dt(ticket_member.joined_at, 'D')}**" + } + + embeds = paginated_embed_menus(names, values, embed_dict=embed_dict) - message_buttons_view = MessageButtonsView(bot) - return message_embed, message_buttons_view + return embeds def close_confirmation( @@ -159,7 +236,7 @@ def close_confirmation( member (discord.Member): The ticket user. Returns: - discord.Embed: Channel embed for close confirmation. + tuple[discord.Embed, discord.ui.View]: Tuple containing channel embed and view for close confirmation. """ confirmation_view = ConfirmationView() @@ -180,7 +257,7 @@ def timeout_confirmation( member (discord.Member): The ticket user. Returns: - discord.Embed: Channel embed for timeout confirmation. + tuple[discord.Embed, discord.ui.View]: Tuple containing channel embed and view for timeout confirmation. """ confirmation_view = ConfirmationView() @@ -201,7 +278,7 @@ def untimeout_confirmation( timeout (int): The timeout as Epoch milliseconds. Returns: - discord.Embed: Channel embed for untimeout confirmation. + tuple[discord.Embed, discord.ui.View]: Tuple containing channel embed and view for untimeout confirmation. """ confirmation_view = ConfirmationView() @@ -219,14 +296,15 @@ def reply_cancel(member: discord.Member, Args: member (discord.Member): The ticket user. + task (asyncio.Task): The task for the reply (e.g., waiting for user message). Returns: - discord.Embed: Channel embed for ticket reply. + tuple[discord.Embed, discord.ui.View]: Tuple containing channel embed and view for reply cancellation. """ cancel_view = CancelView(task) message_embed = discord.Embed( - description=f"Replying to ModMail conversation for **{member}**") + description=f"Replying to ModMail conversation for **{member.name}**") return message_embed, cancel_view @@ -245,7 +323,8 @@ def closed_ticket(staff: Union[discord.User, discord.Member], message_embed = discord.Embed( description= - f"**{staff}** closed the ModMail conversation for **{member}**") + f"**{staff.name}** closed the ModMail conversation for **{member.name}**" + ) return message_embed From 5ba2b646fac83cc2b33ebf925dd5b38e687a802f Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:25:07 +0930 Subject: [PATCH 08/22] Added docstrings and removed `umember.py` --- utils/actions.py | 112 +++++++++++++++++++++++++++++++++++--------- utils/uformatter.py | 9 ++++ utils/umember.py | 26 ---------- 3 files changed, 99 insertions(+), 48 deletions(-) delete mode 100644 utils/umember.py diff --git a/utils/actions.py b/utils/actions.py index 401b066..eb983bb 100644 --- a/utils/actions.py +++ b/utils/actions.py @@ -15,6 +15,16 @@ async def waiter( bot: commands.Bot, interaction: discord.Interaction) -> Optional[discord.Message]: + """ + Waits for a message from the user who initiated the interaction. + + Args: + bot (commands.Bot): The bot object. + interaction (discord.Interaction): The interaction object. + + Returns: + Optional[discord.Message]: The message sent by the user. + """ def check(message: discord.Message) -> bool: return message.author == interaction.user and message.channel == interaction.channel @@ -29,18 +39,29 @@ def check(message: discord.Message) -> bool: async def message_open(bot: commands.Bot, interaction: discord.Interaction, member: discord.Member): + """ + Sends message embed and opens a ticket for the specified user, if not open already. + + Args: + bot (commands.Bot): The bot object. + interaction (discord.Interaction): The interaction object. + member (discord.Member): The member to open a ticket for. + """ ticket = await db.get_ticket_by_user(member.id) if ticket: await interaction.response.send_message( - f'There is already a ticket open for {member.name}.') + f'There is already a ticket open for {member.name}.', + ephemeral=True) return ticket_id = await db.open_ticket(member.id) ticket = await db.get_ticket(ticket_id) - message_embed, buttons_view = await ticket_embed.channel_embed( - bot, interaction.guild, ticket) + embeds = await ticket_embed.channel_embed(interaction.guild, ticket) + + message_embed, buttons_view = await ticket_embed.MessageButtonsView( + bot, embeds).return_paginated_embed() await interaction.response.send_message(embed=message_embed, view=buttons_view) @@ -51,15 +72,25 @@ async def message_open(bot: commands.Bot, interaction: discord.Interaction, async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, member: discord.Member): + """ + Resends message embed for the specified user if open. + + Args: + bot (commands.Bot): The bot object. + interaction (discord.Interaction): The interaction object. + member (discord.Member): The member to refresh the ticket for. + """ ticket = await db.get_ticket_by_user(member.id) if not ticket: await interaction.response.send_message( - f'There is no ticket open for {member.name}.') + f'There is no ticket open for {member.name}.', ephemeral=True) return - message_embed, buttons_view = await ticket_embed.channel_embed( - bot, interaction.guild, ticket) + embeds = await ticket_embed.channel_embed(interaction.guild, ticket) + + message_embed, buttons_view = await ticket_embed.MessageButtonsView( + bot, embeds).return_paginated_embed() await interaction.response.send_message(embed=message_embed, view=buttons_view) @@ -67,14 +98,25 @@ async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, await db.update_ticket_message(ticket.ticket_id, message.id) if ticket.message_id is not None: - old_ticket_message = await interaction.channel.fetch_message( - ticket.message_id) - await old_ticket_message.delete() + try: + old_ticket_message = await interaction.channel.fetch_message( + ticket.message_id) + await old_ticket_message.delete() + except discord.errors.NotFound: + # Pass if original ticket message has been deleted already + pass async def message_close(interaction: discord.Interaction, ticket: db.Ticket, user: discord.Member): - """Sends close confirmation embed, and if confirmed, will close the ticket.""" + """ + Sends close confirmation embed, and if confirmed, will close the ticket. + + Args: + interaction (discord.Interaction): The interaction object. + ticket (db.Ticket): The ticket object. + user (discord.Member): The member to close the ticket for. + """ close_embed, confirmation_view = ticket_embed.close_confirmation(user) @@ -100,12 +142,19 @@ async def message_close(interaction: discord.Interaction, ticket: db.Ticket, async def message_reply(bot: commands.Bot, interaction: discord.Interaction, ticket: db.Ticket): - """Sends reply embed, and if confirmed, will add the staff message to the ticket embeds.""" + """ + Adds staff reply to ticket and updates original embed. + + Args: + bot (commands.Bot): The bot object. + interaction (discord.Interaction): The interaction object. + ticket (db.Ticket): The ticket object. + """ ticket_user = interaction.guild.get_member(ticket.user) if not ticket_user: - await interaction.channel.send( + await interaction.response.send_message( "Cannot reply to ticket as user is not in the server.") return @@ -145,9 +194,13 @@ async def message_reply(bot: commands.Bot, interaction: discord.Interaction, ticket_message = await interaction.channel.fetch_message( ticket.message_id) - channel_embed, _ = await ticket_embed.channel_embed( - bot, interaction.guild, ticket) - await ticket_message.edit(embed=channel_embed) + embeds = await ticket_embed.channel_embed(interaction.guild, + ticket) + + channel_embed, buttons_view = await ticket_embed.MessageButtonsView( + bot, embeds).return_paginated_embed() + + await ticket_message.edit(embed=channel_embed, view=buttons_view) except discord.errors.Forbidden: await interaction.channel.send( "Could not send ModMail message to specified user due to privacy settings." @@ -164,7 +217,8 @@ async def message_timeout(interaction: discord.Interaction, """Sends timeout confirmation embed, and if confirmed, will timeout the specified ticket user. Args: - ticket_user (discord.User): The ticket user. + interaction (discord.Interaction): The interaction object. + member (discord.Member): The member to timeout. """ timeout_embed, confirmation_view = ticket_embed.timeout_confirmation( @@ -188,19 +242,29 @@ async def message_timeout(interaction: discord.Interaction, f'{member.name} has been successfully timed out for 24 hours. They will be able to message ModMail again after .' ) - # TODO: Handle when DMs are disabled - # "Could not send timeout message to specified user due to privacy settings. Timeout has not been set." - await member.send(embed=ticket_embed.user_timeout(timestamp)) + try: + await member.send(embed=ticket_embed.user_timeout(timestamp)) + except discord.errors.Forbidden: + await interaction.channel.send( + "Could not send timeout message to specified user due to privacy settings." + ) async def message_untimeout(interaction: discord.Interaction, member: discord.Member): + """ + Sends untimeout confirmation embed, and if confirmed, will remove the timeout for the specified ticket user. + + Args: + interaction (discord.Interaction): The interaction object. + member (discord.Member): The member to remove the timeout for. + """ timeout = await db.get_timeout(member.id) current_time = int(datetime.datetime.now().timestamp()) if not timeout or (current_time > timeout.timestamp): await interaction.response.send_message( - f'{member.name} is not currently timed out.') + f'{member.name} is not currently timed out.', ephemeral=True) return untimeout_embed, confirmation_view = ticket_embed.untimeout_confirmation( @@ -222,5 +286,9 @@ async def message_untimeout(interaction: discord.Interaction, await interaction.channel.send( f'Timeout has been removed for {member.name}.') - # TODO: Handle when DMs are disabled - await member.send(embed=ticket_embed.user_untimeout()) + try: + await member.send(embed=ticket_embed.user_untimeout()) + except discord.errors.Forbidden: + await interaction.channel.send( + "Could not send untimeout message to specified user due to privacy settings." + ) diff --git a/utils/uformatter.py b/utils/uformatter.py index 57070f4..37c99ac 100644 --- a/utils/uformatter.py +++ b/utils/uformatter.py @@ -2,6 +2,15 @@ def format_message(message: discord.Message) -> str: + """ + Formats message to include attachments as links. + + Args: + message (discord.Message): Message to format. + + Returns: + str: Formatted message. + """ attachments = message.attachments formatted_message = f'{message.content}' for attachment in attachments: diff --git a/utils/umember.py b/utils/umember.py deleted file mode 100644 index 0223733..0000000 --- a/utils/umember.py +++ /dev/null @@ -1,26 +0,0 @@ -import re - -ID_PATTERN = re.compile('^\\d{17,20}$') -MEMBER_MENTION_PATTERN = re.compile('^<@!?\\d{17,20}>$') -NAME_PATTERN = re.compile('^.{2,32}#[0-9]{4}$') - -def get_member(guild, input): - member = None - - if ID_PATTERN.match(input) or MEMBER_MENTION_PATTERN.match(input): - member = guild.get_member(int(re.sub('[^0-9]', '', input))) - elif NAME_PATTERN.match(input): - member = guild.get_member_named(input) - - if member is not None and not member.bot: - return member - else: - return None - -def assert_member(guild, argument): - member = get_member(guild, argument) - - if member is None: - raise RuntimeError("Please specify a valid user.") - - return member From 927d7f660c98aec79f83cd6bec73e1533baf1c90 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:25:29 +0930 Subject: [PATCH 09/22] Updated README and example config. --- README.md | 13 ++++++++++++- config.example.json | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d8abf14..160485f 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,16 @@ ModMail, but in Python +# Configuration + +- `token`: The bot's user token retrieved from Discord Developer Portal. +- `application_id`: The bot's application ID retrieved from Discord Developer Portal. +- `guild`: The guild ID. +- `channel`: Modmail channel ID in specified guild (must be `TextChannel`). +- `prefix`: The bot prefix (needed for slash command sync command). +- `status`: The bot status. +- `id_prefix`: The bot prefix for persistent views (e.g., `mm`) + # Running the bot 1. Navigate to the root directory. @@ -11,7 +21,7 @@ cd /modmail.py ``` 2. Copy the `config.example.json`, rename to `config.json`, and replace the relevant values. -If you want to inject the config at runtime using environment variables, don't replace the values. + If you want to inject the config at runtime using environment variables, don't replace the values. 3. Build the bot using Docker. @@ -28,6 +38,7 @@ docker container run --name modmail \ ``` As aforementioned, you can also inject environment variables. + ``` docker container run --name modmail \ -v database:/database \ diff --git a/config.example.json b/config.example.json index d031dca..7b7afab 100644 --- a/config.example.json +++ b/config.example.json @@ -1,9 +1,9 @@ { "token": "", + "application_id": 0, "guild": 0, "channel": 0, "prefix": "", "status": "", - "application_id": 0, "id_prefix": "" } From d62854340d26eacaac475a4eea1abd2b7d8929b6 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 7 Jul 2023 17:26:10 +0930 Subject: [PATCH 10/22] Changed debug level, added member cache, and updated path to DB. --- db.py | 3 +-- modmail.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/db.py b/db.py index 8fbca10..ac72925 100644 --- a/db.py +++ b/db.py @@ -2,8 +2,7 @@ import aiosqlite from typing import Optional -# FIX: Change this to `/database/modmail.db` before push -path = 'modmail.db' +path = '/database/modmail.db' @asynccontextmanager diff --git a/modmail.py b/modmail.py index 1a36c64..d68e21e 100644 --- a/modmail.py +++ b/modmail.py @@ -9,7 +9,7 @@ from utils.ticket_embed import MessageButtonsView logger = logging.getLogger('bot') -logger.setLevel(logging.DEBUG) # TODO: Change back to logging.INFO +logger.setLevel(logging.INFO) logging.basicConfig( format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', @@ -19,6 +19,8 @@ intents.members = True intents.message_content = True +member_cache = discord.MemberCacheFlags() + class Modmail(commands.Bot): @@ -28,6 +30,7 @@ def __init__(self): command_prefix=config.prefix, description=config.status, application_id=config.application_id, + member_cache_flags=member_cache, ) async def setup_hook(self): @@ -49,7 +52,7 @@ async def setup_hook(self): logger.error(e) logger.info("Loaded all cogs.") - self.add_view(MessageButtonsView(bot)) + self.add_view(MessageButtonsView(bot, [])) logger.info("Added all views.") async def on_ready(self): From 216e57fcf9b12a729b6c8c400d61d2c5ecfca41e Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Mon, 10 Jul 2023 13:42:09 +0930 Subject: [PATCH 11/22] Added CI/CD workflow with GitHub actions and Ansible with default `docker-compose.yml` --- .github/workflows/deploy.yml | 42 +++++++++++++++++++++++++++++++++++ .github/workflows/publish.yml | 39 ++++++++++++++++++++++++++++++++ docker-compose.yml | 12 ++++++++++ playbook.yml | 14 ++++++++++++ 4 files changed, 107 insertions(+) create mode 100644 .github/workflows/deploy.yml create mode 100644 .github/workflows/publish.yml create mode 100644 docker-compose.yml create mode 100644 playbook.yml diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..66d32d7 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,42 @@ +# yaml-language-server: $schema=https://raw.githubusercontent.com/SchemaStore/schemastore/master/src/schemas/json/github-workflow.json + +name: ModMail CD + +on: + workflow_run: + workflows: ["ModMail CI"] + types: + - completed + workflow_dispatch: + +jobs: + lint-playbook: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Run ansible-lint + uses: ansible-community/ansible-lint-action@main + run-ansible-playbook: + runs-on: ubuntu-latest + needs: [lint-playbook] + steps: + - uses: actions/checkout@v3 + - name: Install Ansible via Apt + run: > + sudo apt update && + sudo apt install software-properties-common && + sudo apt-add-repository --yes --update ppa:ansible/ansible && + sudo apt install ansible + - name: Write Inventory to File + env: + INVENTORY: ${{ secrets.INVENTORY }} + run: 'echo "$INVENTORY" > inventory' + - name: Install SSH Key + uses: shimataro/ssh-key-action@v2 + with: + key: ${{ secrets.ANSIBLE_KEY }} + name: ansible + known_hosts: ${{ secrets.KNOWN_HOSTS }} + - name: Run Ansible Playbook + run: | + ansible-playbook -i inventory playbook.yml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..a8d8ac3 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,39 @@ +name: ModMail CI + +on: + push: + branches: ["main"] + + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-publish-latest: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set Up Docker Buildx + uses: docker/setup-buildx-action@v2 + + - name: Log in to the Container registry + uses: docker/login-action@v2 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@v4 + with: + context: . + push: true + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..8ae5b1e --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,12 @@ +version: "3.8" + +services: + modmail: + image: ghcr.io/ib-ai/modmail-py:latest + env_file: modmail.env + restart: on-failure + volumes: + - mm-data:/database + +volumes: + mm-data: diff --git a/playbook.yml b/playbook.yml new file mode 100644 index 0000000..ce21ed0 --- /dev/null +++ b/playbook.yml @@ -0,0 +1,14 @@ +--- +- hosts: all + become: true + tasks: + - name: Pull Latest Docker Image + community.docker.docker_compose: + project_src: /ibobots/modmail + pull: true + state: absent + + - name: Deploy Docker Image + community.docker.docker_compose: + project_src: /ibobots/modmail + build: false From 42dd80ba27e26f87ba8b3ab99abeab84a3edead9 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Mon, 10 Jul 2023 16:19:25 +0930 Subject: [PATCH 12/22] Separated stopping containers and pulling images into separate steps following documentation. --- playbook.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/playbook.yml b/playbook.yml index ce21ed0..d833765 100644 --- a/playbook.yml +++ b/playbook.yml @@ -6,6 +6,10 @@ community.docker.docker_compose: project_src: /ibobots/modmail pull: true + + - name: Stop and Remove Existing Containers + community.docker.docker_compose: + project_src: /ibobots/modmail state: absent - name: Deploy Docker Image From 8b811e83751429c1b70457b946d3bacb3ea89530 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Mon, 10 Jul 2023 16:21:08 +0930 Subject: [PATCH 13/22] Change ordering of stopping and pulling containers. --- playbook.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/playbook.yml b/playbook.yml index d833765..cde46c1 100644 --- a/playbook.yml +++ b/playbook.yml @@ -2,15 +2,15 @@ - hosts: all become: true tasks: - - name: Pull Latest Docker Image + - name: Stop and Remove Existing Containers community.docker.docker_compose: project_src: /ibobots/modmail - pull: true + state: absent - - name: Stop and Remove Existing Containers + - name: Pull Latest Docker Image community.docker.docker_compose: project_src: /ibobots/modmail - state: absent + pull: true - name: Deploy Docker Image community.docker.docker_compose: From 53e89d9b94c9cad16e64fbf2c769d94851e82548 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:23:24 +0930 Subject: [PATCH 14/22] Modified `sync` command permissions --- cogs/commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cogs/commands.py b/cogs/commands.py index 95692c2..f6a0ec1 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -56,7 +56,7 @@ async def interaction_check(self, interaction: discord.Interaction): return True @commands.command(name="sync") - @commands.has_permissions(administrator=True) + @commands.has_permissions(manage_guild=True) @commands.guild_only() async def sync(self, ctx: commands.Context, From c3d4a1dd1042f780e1f70a40a3bece672864084a Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Sat, 26 Aug 2023 14:13:42 +0930 Subject: [PATCH 15/22] Fixed formatting and linting issues. --- .flake8 | 2 + utils/actions.py | 121 +++++++++++++++++++----------------- utils/config.py | 47 ++++++++------ utils/pagination.py | 30 +++++---- utils/ticket_embed.py | 138 ++++++++++++++++++++++-------------------- utils/uformatter.py | 14 ++--- 6 files changed, 190 insertions(+), 162 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..6deafc2 --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 120 diff --git a/utils/actions.py b/utils/actions.py index eb983bb..4e88df5 100644 --- a/utils/actions.py +++ b/utils/actions.py @@ -1,6 +1,7 @@ import asyncio import datetime from typing import Optional + import discord from discord.ext import commands @@ -13,8 +14,8 @@ async def waiter( - bot: commands.Bot, - interaction: discord.Interaction) -> Optional[discord.Message]: + bot: commands.Bot, interaction: discord.Interaction +) -> Optional[discord.Message]: """ Waits for a message from the user who initiated the interaction. @@ -27,18 +28,22 @@ async def waiter( """ def check(message: discord.Message) -> bool: - return message.author == interaction.user and message.channel == interaction.channel + return ( + message.author == interaction.user + and message.channel == interaction.channel + ) try: - message = await bot.wait_for('message', check=check) + message = await bot.wait_for("message", check=check) except asyncio.TimeoutError: return None return message -async def message_open(bot: commands.Bot, interaction: discord.Interaction, - member: discord.Member): +async def message_open( + bot: commands.Bot, interaction: discord.Interaction, member: discord.Member +): """ Sends message embed and opens a ticket for the specified user, if not open already. @@ -51,8 +56,8 @@ async def message_open(bot: commands.Bot, interaction: discord.Interaction, if ticket: await interaction.response.send_message( - f'There is already a ticket open for {member.name}.', - ephemeral=True) + f"There is already a ticket open for {member.name}.", ephemeral=True + ) return ticket_id = await db.open_ticket(member.id) @@ -61,17 +66,18 @@ async def message_open(bot: commands.Bot, interaction: discord.Interaction, embeds = await ticket_embed.channel_embed(interaction.guild, ticket) message_embed, buttons_view = await ticket_embed.MessageButtonsView( - bot, embeds).return_paginated_embed() - await interaction.response.send_message(embed=message_embed, - view=buttons_view) + bot, embeds + ).return_paginated_embed() + await interaction.response.send_message(embed=message_embed, view=buttons_view) ticket_message = await interaction.original_response() logger.debug(f"Ticket message: {ticket_message}") await db.update_ticket_message(ticket.ticket_id, ticket_message.id) -async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, - member: discord.Member): +async def message_refresh( + bot: commands.Bot, interaction: discord.Interaction, member: discord.Member +): """ Resends message embed for the specified user if open. @@ -84,15 +90,16 @@ async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, if not ticket: await interaction.response.send_message( - f'There is no ticket open for {member.name}.', ephemeral=True) + f"There is no ticket open for {member.name}.", ephemeral=True + ) return embeds = await ticket_embed.channel_embed(interaction.guild, ticket) message_embed, buttons_view = await ticket_embed.MessageButtonsView( - bot, embeds).return_paginated_embed() - await interaction.response.send_message(embed=message_embed, - view=buttons_view) + bot, embeds + ).return_paginated_embed() + await interaction.response.send_message(embed=message_embed, view=buttons_view) message = await interaction.original_response() await db.update_ticket_message(ticket.ticket_id, message.id) @@ -100,15 +107,17 @@ async def message_refresh(bot: commands.Bot, interaction: discord.Interaction, if ticket.message_id is not None: try: old_ticket_message = await interaction.channel.fetch_message( - ticket.message_id) + ticket.message_id + ) await old_ticket_message.delete() except discord.errors.NotFound: # Pass if original ticket message has been deleted already pass -async def message_close(interaction: discord.Interaction, ticket: db.Ticket, - user: discord.Member): +async def message_close( + interaction: discord.Interaction, ticket: db.Ticket, user: discord.Member +): """ Sends close confirmation embed, and if confirmed, will close the ticket. @@ -120,28 +129,27 @@ async def message_close(interaction: discord.Interaction, ticket: db.Ticket, close_embed, confirmation_view = ticket_embed.close_confirmation(user) - await interaction.response.send_message(embed=close_embed, - view=confirmation_view) + await interaction.response.send_message(embed=close_embed, view=confirmation_view) confirmation_view.message = await interaction.original_response() await confirmation_view.wait() - if confirmation_view.value is None: + if not confirmation_view.value: return elif confirmation_view.value: await db.close_ticket(ticket.ticket_id) - ticket_message = await interaction.channel.fetch_message( - ticket.message_id) + ticket_message = await interaction.channel.fetch_message(ticket.message_id) await ticket_message.delete() await interaction.channel.send( - embed=ticket_embed.closed_ticket(interaction.user, user)) - logger.info( - f"Ticket for user {user.id} closed by {interaction.user.id}") + embed=ticket_embed.closed_ticket(interaction.user, user) + ) + logger.info(f"Ticket for user {user.id} closed by {interaction.user.id}") -async def message_reply(bot: commands.Bot, interaction: discord.Interaction, - ticket: db.Ticket): +async def message_reply( + bot: commands.Bot, interaction: discord.Interaction, ticket: db.Ticket +): """ Adds staff reply to ticket and updates original embed. @@ -155,14 +163,14 @@ async def message_reply(bot: commands.Bot, interaction: discord.Interaction, if not ticket_user: await interaction.response.send_message( - "Cannot reply to ticket as user is not in the server.") + "Cannot reply to ticket as user is not in the server." + ) return task = bot.loop.create_task(waiter(bot, interaction)) reply_embed, cancel_view = ticket_embed.reply_cancel(ticket_user, task) - await interaction.response.send_message(embed=reply_embed, - view=cancel_view) + await interaction.response.send_message(embed=reply_embed, view=cancel_view) cancel_view.message = await interaction.original_response() await task @@ -182,23 +190,24 @@ async def message_reply(bot: commands.Bot, interaction: discord.Interaction, # ! Fix for longer messages if len(response) > 1000: await interaction.channel.send( - 'Your message is too long. Please shorten your message or send in multiple parts.' + "Your message is too long. Please shorten your message or send in multiple parts." ) return try: await ticket_user.send( - embed=ticket_embed.user_embed(interaction.guild, response)) - await db.add_ticket_response(ticket.ticket_id, interaction.user.id, - response, True) - ticket_message = await interaction.channel.fetch_message( - ticket.message_id) + embed=ticket_embed.user_embed(interaction.guild, response) + ) + await db.add_ticket_response( + ticket.ticket_id, interaction.user.id, response, True + ) + ticket_message = await interaction.channel.fetch_message(ticket.message_id) - embeds = await ticket_embed.channel_embed(interaction.guild, - ticket) + embeds = await ticket_embed.channel_embed(interaction.guild, ticket) channel_embed, buttons_view = await ticket_embed.MessageButtonsView( - bot, embeds).return_paginated_embed() + bot, embeds + ).return_paginated_embed() await ticket_message.edit(embed=channel_embed, view=buttons_view) except discord.errors.Forbidden: @@ -212,8 +221,7 @@ async def message_reply(bot: commands.Bot, interaction: discord.Interaction, return -async def message_timeout(interaction: discord.Interaction, - member: discord.Member): +async def message_timeout(interaction: discord.Interaction, member: discord.Member): """Sends timeout confirmation embed, and if confirmed, will timeout the specified ticket user. Args: @@ -221,11 +229,9 @@ async def message_timeout(interaction: discord.Interaction, member (discord.Member): The member to timeout. """ - timeout_embed, confirmation_view = ticket_embed.timeout_confirmation( - member) + timeout_embed, confirmation_view = ticket_embed.timeout_confirmation(member) - await interaction.response.send_message(embed=timeout_embed, - view=confirmation_view) + await interaction.response.send_message(embed=timeout_embed, view=confirmation_view) confirmation_view.message = await interaction.original_response() await confirmation_view.wait() @@ -239,7 +245,7 @@ async def message_timeout(interaction: discord.Interaction, logger.info(f"User {member.id} timed out by {interaction.user.id}") await interaction.channel.send( - f'{member.name} has been successfully timed out for 24 hours. They will be able to message ModMail again after .' + f"{member.name} has been successfully timed out for 24 hours. They will be able to message ModMail again after ." ) try: @@ -250,8 +256,7 @@ async def message_timeout(interaction: discord.Interaction, ) -async def message_untimeout(interaction: discord.Interaction, - member: discord.Member): +async def message_untimeout(interaction: discord.Interaction, member: discord.Member): """ Sends untimeout confirmation embed, and if confirmed, will remove the timeout for the specified ticket user. @@ -264,14 +269,17 @@ async def message_untimeout(interaction: discord.Interaction, if not timeout or (current_time > timeout.timestamp): await interaction.response.send_message( - f'{member.name} is not currently timed out.', ephemeral=True) + f"{member.name} is not currently timed out.", ephemeral=True + ) return untimeout_embed, confirmation_view = ticket_embed.untimeout_confirmation( - member, timeout.timestamp) + member, timeout.timestamp + ) - await interaction.response.send_message(embed=untimeout_embed, - view=confirmation_view) + await interaction.response.send_message( + embed=untimeout_embed, view=confirmation_view + ) confirmation_view.message = await interaction.original_response() await confirmation_view.wait() @@ -283,8 +291,7 @@ async def message_untimeout(interaction: discord.Interaction, await db.set_timeout(member.id, timestamp) logger.info(f"Timeout removed for {member.id}.") - await interaction.channel.send( - f'Timeout has been removed for {member.name}.') + await interaction.channel.send(f"Timeout has been removed for {member.name}.") try: await member.send(embed=ticket_embed.user_untimeout()) diff --git a/utils/config.py b/utils/config.py index 619effc..6d6adcf 100644 --- a/utils/config.py +++ b/utils/config.py @@ -3,21 +3,34 @@ from pathlib import Path _path = Path(__file__).parent / "../config.json" -_config = json.load(open(_path, 'r')) +_config = json.load(open(_path, "r")) -token = os.getenv( - "MODMAIL_TOKEN") if "MODMAIL_TOKEN" in os.environ else _config["token"] -application_id = int( - os.getenv("MODMAIL_APPLICATION_ID") -) if "MODMAIL_APPLICATION_ID" in os.environ else _config["application_id"] -guild = int(os.getenv( - "MODMAIL_GUILD")) if "MODMAIL_GUILD" in os.environ else _config["guild"] -channel = int(os.getenv("MODMAIL_CHANNEL") - ) if "MODMAIL_CHANNEL" in os.environ else _config["channel"] -prefix = os.getenv( - "MODMAIL_PREFIX") if "MODMAIL_PREFIX" in os.environ else _config["prefix"] -status = os.getenv( - "MODMAIL_STATUS") if "MODMAIL_STATUS" in os.environ else _config["status"] -id_prefix = os.getenv( - "MODMAIL_ID_PREFIX" -) if "MODMAIL_ID_PREFIX" in os.environ else _config["id_prefix"] +token = ( + os.getenv("MODMAIL_TOKEN") if "MODMAIL_TOKEN" in os.environ else _config["token"] +) +application_id = ( + int(os.getenv("MODMAIL_APPLICATION_ID")) + if "MODMAIL_APPLICATION_ID" in os.environ + else _config["application_id"] +) +guild = ( + int(os.getenv("MODMAIL_GUILD")) + if "MODMAIL_GUILD" in os.environ + else _config["guild"] +) +channel = ( + int(os.getenv("MODMAIL_CHANNEL")) + if "MODMAIL_CHANNEL" in os.environ + else _config["channel"] +) +prefix = ( + os.getenv("MODMAIL_PREFIX") if "MODMAIL_PREFIX" in os.environ else _config["prefix"] +) +status = ( + os.getenv("MODMAIL_STATUS") if "MODMAIL_STATUS" in os.environ else _config["status"] +) +id_prefix = ( + os.getenv("MODMAIL_ID_PREFIX") + if "MODMAIL_ID_PREFIX" in os.environ + else _config["id_prefix"] +) diff --git a/utils/pagination.py b/utils/pagination.py index ac6de0f..c12a184 100644 --- a/utils/pagination.py +++ b/utils/pagination.py @@ -31,29 +31,26 @@ def paginated_embed_menus( N = len(names) if N != len(values): raise ValueError( - 'names and values for paginated embed menus must be of equal length.' + "names and values for paginated embed menus must be of equal length." ) if isinstance(inline, bool): inline = [inline] * N elif N != len(inline): raise ValueError( - '"inline" must be boolean or a collection of booleans of equal length to names/values for paginated embed menus.' + "'inline' must be boolean or a collection of booleans of equal length to names/values for paginated embed menus." ) if embed_dict: - if 'title' in embed_dict and len(embed_dict['title']) > 256: - raise ValueError('title cannot be over 256 characters') - if 'description' in embed_dict and len( - embed_dict['description']) > 4096: - raise ValueError('desription cannot be over 4096 characters') - if 'footer' in embed_dict: - raise ValueError('embed_dict "footer" key must not be set.') - if 'fields' in embed_dict: - raise ValueError('embed_dict "fields" key must not be set.') + if "title" in embed_dict and len(embed_dict["title"]) > 256: + raise ValueError("title cannot be over 256 characters") + if "description" in embed_dict and len(embed_dict["description"]) > 4096: + raise ValueError("desription cannot be over 4096 characters") + if "footer" in embed_dict: + raise ValueError("embed_dict 'footer' key must not be set.") + if "fields" in embed_dict: + raise ValueError("embed_dict 'fields' key must not be set.") else: - embed_dict = { # default - 'description': 'Here is a list of entries.' - } + embed_dict = {"description": "Here is a list of entries."} # default if N == 0: return [discord.Embed.from_dict(embed_dict)] @@ -63,8 +60,9 @@ def paginated_embed_menus( pages = 1 items = 0 for name, value, inline_field in zip(names, values, inline): - if items == pagesize or len(current) + len(name) + len( - value) > 5090: # leave 10 chars for footers + if ( + items == pagesize or len(current) + len(name) + len(value) > 5090 + ): # leave 10 chars for footers embeds.append(current) current = discord.Embed.from_dict(embed_dict) pages += 1 diff --git a/utils/ticket_embed.py b/utils/ticket_embed.py index fef5da2..476493b 100644 --- a/utils/ticket_embed.py +++ b/utils/ticket_embed.py @@ -1,25 +1,23 @@ import asyncio +from typing import Collection, Optional, Union + import discord from discord.ext import commands from discord.utils import format_dt -import db -from typing import Collection, Optional, Union +import db from utils import actions, config +from utils.pagination import paginated_embed_menus import logging -from utils.pagination import paginated_embed_menus - logger = logging.getLogger(__name__) class ConfirmationView(discord.ui.View): """Confirmation view for yes/no operations.""" - def __init__(self, - message: Optional[discord.Message] = None, - timeout: int = 60): + def __init__(self, message: Optional[discord.Message] = None, timeout: int = 60): super().__init__(timeout=timeout) self.message = message self.value = None @@ -28,16 +26,14 @@ async def on_timeout(self) -> None: await self.message.delete() await super().on_timeout() - @discord.ui.button(label='Yes', style=discord.ButtonStyle.green) - async def yes(self, interaction: discord.Interaction, - button: discord.ui.Button): + @discord.ui.button(label="Yes", style=discord.ButtonStyle.green) + async def yes(self, interaction: discord.Interaction, button: discord.ui.Button): self.value = True await self.message.delete() self.stop() - @discord.ui.button(label='No', style=discord.ButtonStyle.red) - async def no(self, interaction: discord.Interaction, - button: discord.ui.Button): + @discord.ui.button(label="No", style=discord.ButtonStyle.red) + async def no(self, interaction: discord.Interaction, button: discord.ui.Button): self.value = False await self.message.delete() self.stop() @@ -46,10 +42,12 @@ async def no(self, interaction: discord.Interaction, class CancelView(discord.ui.View): """Cancel view for cancelling operations if requested by the user.""" - def __init__(self, - task: asyncio.Task, - message: Optional[discord.Message] = None, - timeout: int = 60) -> None: + def __init__( + self, + task: asyncio.Task, + message: Optional[discord.Message] = None, + timeout: int = 60, + ) -> None: super().__init__(timeout=timeout) self.message = message self.task = task @@ -65,9 +63,8 @@ async def on_timeout(self) -> None: await self.view_cleanup() await super().on_timeout() - @discord.ui.button(label='Cancel', style=discord.ButtonStyle.red) - async def cancel(self, interaction: discord.Interaction, - button: discord.ui.Button): + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.red) + async def cancel(self, interaction: discord.Interaction, button: discord.ui.Button): await self.view_cleanup() @@ -95,7 +92,8 @@ async def mail_close(self, interaction: discord.Interaction, _): """ ticket = await db.get_ticket_by_message(interaction.message.id) member = interaction.guild.get_member( - ticket.user) or await interaction.guild.fetch_member(ticket.user) + ticket.user + ) or await interaction.guild.fetch_member(ticket.user) await actions.message_close(interaction, ticket, member) @discord.ui.button(emoji="⏲️", custom_id=f"{config.id_prefix}:timeout") @@ -105,20 +103,24 @@ async def mail_timeout(self, interaction: discord.Interaction, _): """ ticket = await db.get_ticket_by_message(interaction.message.id) member = interaction.guild.get_member( - ticket.user) or await interaction.guild.fetch_member(ticket.user) + ticket.user + ) or await interaction.guild.fetch_member(ticket.user) await actions.message_timeout(interaction, member) - @discord.ui.button(emoji="⬅️", - style=discord.ButtonStyle.blurple, - custom_id=f"{config.id_prefix}:before_page") - async def before_page(self, interaction: discord.Interaction, _): + @discord.ui.button( + emoji="⬅️", + style=discord.ButtonStyle.blurple, + custom_id=f"{config.id_prefix}:previous_page", + ) + async def previous_page(self, interaction: discord.Interaction, _): """ Goes to the previous page. """ if len(self.embeds) == 0: await interaction.response.send_message( "Please refresh this ticket to be able to use pagination.", - ephemeral=True) + ephemeral=True, + ) return if self.current_page > 0: @@ -126,9 +128,11 @@ async def before_page(self, interaction: discord.Interaction, _): self.update_pagination_buttons() await self.update_view(interaction) - @discord.ui.button(emoji="➡️", - style=discord.ButtonStyle.blurple, - custom_id=f"{config.id_prefix}:next_page") + @discord.ui.button( + emoji="➡️", + style=discord.ButtonStyle.blurple, + custom_id=f"{config.id_prefix}:next_page", + ) async def next_page(self, interaction: discord.Interaction, _): """ Goes to the next page. @@ -136,7 +140,8 @@ async def next_page(self, interaction: discord.Interaction, _): if len(self.embeds) == 0: await interaction.response.send_message( "Please refresh this ticket to be able to use pagination.", - ephemeral=True) + ephemeral=True, + ) return if self.current_page < len(self.embeds) - 1: @@ -160,10 +165,12 @@ async def update_view(self, interaction: discord.Interaction): Updates the embed and view. """ await interaction.response.edit_message( - embed=self.embeds[self.current_page], view=self) + embed=self.embeds[self.current_page], view=self + ) async def return_paginated_embed( - self) -> tuple[discord.Embed, discord.ui.View | None]: + self, + ) -> tuple[discord.Embed, discord.ui.View | None]: """ Returns the first embed and containing view. """ @@ -183,14 +190,16 @@ def user_embed(guild: discord.Guild, message: str) -> discord.Embed: discord.Embed: User DM embed containing the message content. """ - message_embed = discord.Embed(title=f"New Mail from {guild.name}", - description=message) + message_embed = discord.Embed( + title=f"New Mail from {guild.name}", description=message + ) return message_embed -async def channel_embed(guild: discord.Guild, - ticket: db.Ticket) -> Collection[discord.Embed]: +async def channel_embed( + guild: discord.Guild, ticket: db.Ticket +) -> Collection[discord.Embed]: """Returns formatted embed for modmail channel. Args: @@ -202,7 +211,8 @@ async def channel_embed(guild: discord.Guild, """ ticket_member = guild.get_member(ticket.user) or await guild.fetch_member( - ticket.user) + ticket.user + ) responses = await db.get_ticket_responses(ticket.ticket_id) @@ -210,17 +220,15 @@ async def channel_embed(guild: discord.Guild, values = [] for response in responses: - author = 'user' + author = "user" if response.as_server: - author = f'{guild.get_member(response.user)} as server' + author = f"{guild.get_member(response.user)} as server" names.append(f", {author} wrote") values.append(response.response) embed_dict = { - "title": - f"ModMail Conversation for {ticket_member.name}", - "description": - f"User {ticket_member.mention} has **{len(ticket_member.roles) - 1}** roles\n Joined Discord: **{format_dt(ticket_member.created_at, 'D')}**\n Joined Server: **{format_dt(ticket_member.joined_at, 'D')}**" + "title": f"ModMail Conversation for {ticket_member.name}", + "description": f"User {ticket_member.mention} has **{len(ticket_member.roles) - 1}** roles\n Joined Discord: **{format_dt(ticket_member.created_at, 'D')}**\n Joined Server: **{format_dt(ticket_member.joined_at, 'D')}**", } embeds = paginated_embed_menus(names, values, embed_dict=embed_dict) @@ -228,8 +236,7 @@ async def channel_embed(guild: discord.Guild, return embeds -def close_confirmation( - member: discord.Member) -> tuple[discord.Embed, discord.ui.View]: +def close_confirmation(member: discord.Member) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket close confirmation. Args: @@ -242,15 +249,15 @@ def close_confirmation( confirmation_view = ConfirmationView() message_embed = discord.Embed( - description= - f"Do you want to close the ModMail conversation for **{member.name}**?" + description=f"Do you want to close the ModMail conversation for **{member.name}**?" ) return message_embed, confirmation_view def timeout_confirmation( - member: discord.Member) -> tuple[discord.Embed, discord.ui.View]: + member: discord.Member, +) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket timeout confirmation. Args: @@ -263,14 +270,15 @@ def timeout_confirmation( confirmation_view = ConfirmationView() message_embed = discord.Embed( - description=f"Do you want to timeout **{member.name}** for 24 hours?") + description=f"Do you want to timeout **{member.name}** for 24 hours?" + ) return message_embed, confirmation_view def untimeout_confirmation( - member: discord.Member, - timeout: int) -> tuple[discord.Embed, discord.ui.View]: + member: discord.Member, timeout: int +) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for ticket untimeout confirmation. Args: @@ -283,15 +291,15 @@ def untimeout_confirmation( confirmation_view = ConfirmationView() message_embed = discord.Embed( - description= - f"Do you want to untimeout **{member.name}** (they are currently timed out until )?" + description=f"Do you want to untimeout **{member.name}** (they are currently timed out until )?" ) return message_embed, confirmation_view -def reply_cancel(member: discord.Member, - task: asyncio.Task) -> tuple[discord.Embed, discord.ui.View]: +def reply_cancel( + member: discord.Member, task: asyncio.Task +) -> tuple[discord.Embed, discord.ui.View]: """Returns embed for replying to ticket with cancel reaction. Args: @@ -304,13 +312,15 @@ def reply_cancel(member: discord.Member, cancel_view = CancelView(task) message_embed = discord.Embed( - description=f"Replying to ModMail conversation for **{member.name}**") + description=f"Replying to ModMail conversation for **{member.name}**" + ) return message_embed, cancel_view -def closed_ticket(staff: Union[discord.User, discord.Member], - member: discord.Member) -> discord.Embed: +def closed_ticket( + staff: Union[discord.User, discord.Member], member: discord.Member +) -> discord.Embed: """Returns embed for closed ticket. Args: @@ -322,8 +332,7 @@ def closed_ticket(staff: Union[discord.User, discord.Member], """ message_embed = discord.Embed( - description= - f"**{staff.name}** closed the ModMail conversation for **{member.name}**" + description=f"**{staff.name}** closed the ModMail conversation for **{member.name}**" ) return message_embed @@ -340,8 +349,7 @@ def user_timeout(timeout: int) -> discord.Embed: """ message_embed = discord.Embed( - description= - f"You have been timed out. You will be able to message ModMail again after ." + description=f"You have been timed out. You will be able to message ModMail again after ()." ) return message_embed @@ -355,7 +363,7 @@ def user_untimeout() -> discord.Embed: """ message_embed = discord.Embed( - description= - "Your timeout has been removed. You can message ModMail again.") + description="Your timeout has been removed. You can message ModMail again." + ) return message_embed diff --git a/utils/uformatter.py b/utils/uformatter.py index 37c99ac..f097d5a 100644 --- a/utils/uformatter.py +++ b/utils/uformatter.py @@ -12,13 +12,13 @@ def format_message(message: discord.Message) -> str: str: Formatted message. """ attachments = message.attachments - formatted_message = f'{message.content}' + formatted_message = f"{message.content}".strip() for attachment in attachments: - attachment_type = 'Unknown' - if attachment.content_type.startswith('image'): - attachment_type = 'Image' - elif attachment.content_type.startswith('video'): - attachment_type = 'Video' - formatted_message += f'\n[{attachment_type} Attachment]({attachment.url})' + attachment_type = "Unknown" + if attachment.content_type.startswith("image"): + attachment_type = "Image" + elif attachment.content_type.startswith("video"): + attachment_type = "Video" + formatted_message += f"\n[{attachment_type} Attachment]({attachment.url})" return formatted_message From 056321f1ac468f38a4e5c857bc95dec647d8da6e Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Sat, 26 Aug 2023 14:14:46 +0930 Subject: [PATCH 16/22] Modified cog setup functions/initialisation and error handling. Fixed formatting and linting issues. --- cogs/commands.py | 91 +++++++++++++++++++++++++++-------------------- cogs/listeners.py | 54 ++++++++++++++-------------- modmail.py | 47 ++++++++++++------------ 3 files changed, 102 insertions(+), 90 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index f6a0ec1..006124a 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -1,8 +1,8 @@ from typing import Literal, Optional + from discord.ext import commands from discord import app_commands import discord -import sys import db from utils import config, actions @@ -15,8 +15,7 @@ class Commands(commands.Cog): """Cog to contain command action methods.""" - def __init__(self, bot: commands.Bot, - modmail_channel: discord.TextChannel) -> None: + def __init__(self, bot: commands.Bot, modmail_channel: discord.TextChannel) -> None: """Constructs necessary attributes for all command action methods. Args: @@ -41,16 +40,20 @@ async def cog_check(self, ctx: commands.Context): async def interaction_check(self, interaction: discord.Interaction): if interaction.channel != self.modmail_channel: await interaction.response.send_message( - "Command must be used in the modmail channel.") + "Command must be used in the modmail channel." + ) return False - if "resolved" in interaction.data and "users" in interaction.data[ - "resolved"] and len(interaction.data["resolved"]["users"]) > 0: + if ( + interaction.data + and "resolved" in interaction.data + and "users" in interaction.data["resolved"] + and len(interaction.data["resolved"]["users"]) > 0 + ): user_id = next(iter(interaction.data["resolved"]["users"])) user = interaction.data["resolved"]["users"][user_id] if "bot" in user and user["bot"]: - await interaction.response.send_message( - "Invalid user specified.") + await interaction.response.send_message("Invalid user specified.") return False return True @@ -58,9 +61,7 @@ async def interaction_check(self, interaction: discord.Interaction): @commands.command(name="sync") @commands.has_permissions(manage_guild=True) @commands.guild_only() - async def sync(self, - ctx: commands.Context, - spec: Optional[Literal["~"]] = None): + async def sync(self, ctx: commands.Context, spec: Optional[Literal["~"]] = None): """ Syncs commands to the current guild or globally. @@ -81,84 +82,96 @@ async def sync(self, @app_commands.command(name="open") @commands.guild_only() - async def open_ticket(self, interaction: discord.Interaction, - member: discord.Member): + async def open_ticket( + self, interaction: discord.Interaction, member: discord.Member + ): """Opens ticket for specified user if no tickets are currently open.""" await actions.message_open(self.bot, interaction, member) @app_commands.command(name="refresh") @commands.guild_only() - async def refresh_ticket(self, interaction: discord.Interaction, - member: discord.Member): + async def refresh_ticket( + self, interaction: discord.Interaction, member: discord.Member + ): """Resends embed for specified user if there is a ticket that is already open.""" await actions.message_refresh(self.bot, interaction, member) @app_commands.command(name="close") @commands.guild_only() - async def close_ticket(self, interaction: discord.Interaction, - member: discord.Member): + async def close_ticket( + self, interaction: discord.Interaction, member: discord.Member + ): """Closes ticket for specified user given that a ticket is already open.""" ticket = await db.get_ticket_by_user(member.id) if not ticket: await interaction.response.send_message( - f'There is no ticket open for {member.name}.', ephemeral=True) + f"There is no ticket open for {member.name}.", ephemeral=True + ) return await actions.message_close(interaction, ticket, member) @app_commands.command(name="timeout") @commands.guild_only() - async def timeout_ticket(self, interaction: discord.Interaction, - member: discord.Member): + async def timeout_ticket( + self, interaction: discord.Interaction, member: discord.Member + ): """Times out specified user.""" await actions.message_timeout(interaction, member) @app_commands.command(name="untimeout") @commands.guild_only() - async def untimeout_ticket(self, interaction: discord.Interaction, - member: discord.Member): + async def untimeout_ticket( + self, interaction: discord.Interaction, member: discord.Member + ): """Removes timeout for specified user given that user is currently timed out.""" await actions.message_untimeout(interaction, member) - async def cog_command_error(self, ctx: commands.Context, - error: commands.CommandError): - if type(error) == commands.errors.CheckFailure: + async def cog_command_error( + self, ctx: commands.Context, error: commands.CommandError + ): + if isinstance(error, commands.errors.CheckFailure): logger.error("Checks failed for interaction.") - elif type(error) == commands.errors.MissingRequiredArgument: + if isinstance(error, commands.errors.MissingRequiredArgument): await ctx.send( "A valid user (and one who is still on the server) was not specified." ) else: await ctx.send( - str(error) + "\nIf you do not understand, contact a bot dev.") + str(error) + "\nIf you do not understand, contact a bot dev." + ) - async def cog_app_command_error(self, interaction: discord.Interaction, - error: app_commands.AppCommandError): - if type(error) == app_commands.errors.CheckFailure: + async def cog_app_command_error( + self, interaction: discord.Interaction, error: app_commands.AppCommandError + ): + if isinstance(error, app_commands.errors.CheckFailure): logger.error("Checks failed for interaction.") else: await interaction.response.send_message( - str(error) + "\nIf you do not understand, contact a bot dev.") + str(error) + "\nIf you do not understand, contact a bot dev." + ) logger.error(error) async def setup(bot: commands.Bot): + """Setup function for the listeners cog. + + Args: + bot (commands.Bot): The bot. + """ try: modmail_channel = await bot.fetch_channel(config.channel) - - if type(modmail_channel) != discord.TextChannel: - raise TypeError( - "The channel specified in config was not a text channel.") except Exception as e: - logger.error(e) - logger.fatal( + raise ValueError( "The channel specified in config was not found. Please check your config." - ) - sys.exit(-1) + ) from e + + if not isinstance(modmail_channel, discord.TextChannel): + raise TypeError("The channel specified in config was not a text channel.") await bot.add_cog(Commands(bot, modmail_channel)) diff --git a/cogs/listeners.py b/cogs/listeners.py index 5e56363..ff3e450 100644 --- a/cogs/listeners.py +++ b/cogs/listeners.py @@ -1,7 +1,7 @@ +import datetime + from discord.ext import commands import discord -import datetime -import sys import db from utils import config, uformatter, ticket_embed @@ -14,8 +14,7 @@ class Listeners(commands.Cog): """Cog to contain all main listener methods.""" - def __init__(self, bot: commands.Bot, - modmail_channel: discord.TextChannel) -> None: + def __init__(self, bot: commands.Bot, modmail_channel: discord.TextChannel) -> None: """Constructs necessary attributes for all command action methods. Args: @@ -36,12 +35,13 @@ async def on_message(self, message: discord.Message): if message.guild is None and not message.author.bot: # Handle if user is not in guild - if not (self.modmail_channel.guild.get_member(message.author.id) - or await self.modmail_channel.guild.fetch_member( - message.author.id)): + if not ( + self.modmail_channel.guild.get_member(message.author.id) + or await self.modmail_channel.guild.fetch_member(message.author.id) + ): try: await message.author.send( - 'Unable to send message. Please ensure you have joined the server.' + "Unable to send message. Please ensure you have joined the server." ) except discord.errors.Forbidden: pass @@ -67,13 +67,13 @@ async def handle_dm(self, message: discord.Message): response = uformatter.format_message(message) - if not response.strip(): + if not response: return # ! Fix for longer messages if len(response) > 1000: await message.channel.send( - 'Your message is too long. Please shorten your message or send in multiple parts.' + "Your message is too long. Please shorten your message or send in multiple parts." ) return @@ -87,27 +87,28 @@ async def handle_dm(self, message: discord.Message): try: if ticket and ticket.message_id is not None: old_ticket_message = await self.modmail_channel.fetch_message( - ticket.message_id) + ticket.message_id + ) await old_ticket_message.delete() except discord.errors.NotFound: await message.channel.send( - 'You are being rate limited. Please wait a few seconds before trying again.' + "You are being rate limited. Please wait a few seconds before trying again." ) return # `ticket` truthiness has been checked prior to the following lines - await db.add_ticket_response(ticket.ticket_id, user.id, response, - False) + await db.add_ticket_response(ticket.ticket_id, user.id, response, False) - embeds = await ticket_embed.channel_embed(self.modmail_channel.guild, - ticket) + embeds = await ticket_embed.channel_embed(self.modmail_channel.guild, ticket) message_embed, buttons_view = await ticket_embed.MessageButtonsView( - self.bot, embeds).return_paginated_embed() + self.bot, embeds + ).return_paginated_embed() - ticket_message = await self.modmail_channel.send(embed=message_embed, - view=buttons_view) - await message.add_reaction('📨') + ticket_message = await self.modmail_channel.send( + embed=message_embed, view=buttons_view + ) + await message.add_reaction("📨") await db.update_ticket_message(ticket.ticket_id, ticket_message.id) @@ -120,15 +121,12 @@ async def setup(bot: commands.Bot): """ try: modmail_channel = await bot.fetch_channel(config.channel) - - if type(modmail_channel) != discord.TextChannel: - raise TypeError( - "The channel specified in config was not a text channel.") except Exception as e: - logger.error(e) - logger.fatal( + raise ValueError( "The channel specified in config was not found. Please check your config." - ) - sys.exit(-1) + ) from e + + if not isinstance(modmail_channel, discord.TextChannel): + raise TypeError("The channel specified in config was not a text channel.") await bot.add_cog(Listeners(bot, modmail_channel)) diff --git a/modmail.py b/modmail.py index d68e21e..fa7997b 100644 --- a/modmail.py +++ b/modmail.py @@ -3,17 +3,17 @@ import db from utils import config +from utils.ticket_embed import MessageButtonsView import logging -from utils.ticket_embed import MessageButtonsView - -logger = logging.getLogger('bot') +logger = logging.getLogger("bot") logger.setLevel(logging.INFO) logging.basicConfig( - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%d-%b-%y %H:%M:%S') + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%d-%b-%y %H:%M:%S", +) intents = discord.Intents.default() intents.members = True @@ -21,9 +21,10 @@ member_cache = discord.MemberCacheFlags() +INITIAL_COGS = ["commands", "listeners"] -class Modmail(commands.Bot): +class Modmail(commands.Bot): def __init__(self): super().__init__( intents=intents, @@ -34,32 +35,32 @@ def __init__(self): ) async def setup_hook(self): - if await db.init(): - logger.info('Database sucessfully initialized!') - else: - logger.error('Error while initializing database!') - return + await db.init() + logger.info("Database sucessfully initialized!") - for cog in ('commands', 'listeners'): + for cog in INITIAL_COGS: try: - await bot.load_extension(f'cogs.{cog}') - logger.debug(f'Imported cog "{cog}".') - except commands.errors.NoEntryPointError as e: - logger.warning(e) - except commands.errors.ExtensionNotFound as e: - logger.warning(e) - except commands.errors.ExtensionFailed as e: - logger.error(e) + await bot.load_extension(f"cogs.{cog}") + logger.debug(f"Imported cog '{cog}'.") + except ( + commands.errors.NoEntryPointError, + commands.errors.ExtensionNotFound, + commands.errors.ExtensionFailed, + ) as e: + logger.fatal(f"Failed to import cog '{cog}'.") + raise SystemExit(e) + logger.info("Loaded all cogs.") self.add_view(MessageButtonsView(bot, [])) logger.info("Added all views.") async def on_ready(self): - await bot.change_presence(activity=discord.Game(name=config.status), - status=discord.Status.online) + await bot.change_presence( + activity=discord.Game(name=config.status), status=discord.Status.online + ) - logger.info(f"Bot \"{bot.user.name}\" is now connected.") + logger.info(f"Bot '{bot.user.name}' is now connected.") async def on_command_error(self, ctx: commands.Context, exception) -> None: await super().on_command_error(ctx, exception) From fd3571cfaae13926d3eb2f03a03ed627f192fc20 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Sat, 26 Aug 2023 14:15:28 +0930 Subject: [PATCH 17/22] Added `CONTRIBUTING.md` and `pyproject.toml`. Updated `README.md` to include sample config and contributing section. --- CONTRIBUTING.md | 20 ++++++++++++++++++++ README.md | 18 ++++++++++++++++++ pyproject.toml | 5 +++++ 3 files changed, 43 insertions(+) create mode 100644 CONTRIBUTING.md create mode 100644 pyproject.toml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..986859a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,20 @@ +If you are reading this page, then thank you for your interest in contributing towards the bot. +We are grateful for any help, however, please ensure you follow the guidelines laid out below +and ensure that any code you produce for modmail.py is licensed with the GNU GPL v3. + +# Contributions + +Not all contribution PRs (read below) will be accepted. +For ideas as to what to contribute, please refer to the GitHub issues or contact a member of the development team. + +Bug fixes and optimisations are also greatly appreciated! + +# VCS + +1. Create a new branch (and fork if applicable). Label it appropriately. +2. Make the changes on that branch. +3. Commit to and push the changes. +4. Create a PR from your branch to `main`. +5. A maintainer will then review your PR. + +If you have questions, please ask a maintainer. diff --git a/README.md b/README.md index 160485f..2331b7d 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,20 @@ ModMail, but in Python - `status`: The bot status. - `id_prefix`: The bot prefix for persistent views (e.g., `mm`) +## Sample `config.json` + +```json +{ + "token": "abc123", + "application_id": 1234567890, + "guild": 1234567890, + "channel": 1234567890, + "prefix": "]", + "status": "DM to contact", + "id_prefix": "mm" +} +``` + # Running the bot 1. Navigate to the root directory. @@ -48,3 +62,7 @@ docker container run --name modmail \ -e MODMAIL_PREFIX=! \ modmail-py ``` + +# Contributions + +For information regarding contributing to this project, please read [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8fb8be0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[project] +name = "modmail.py" +version = "2.0.0" +readme = "README.md" +description = "A Discord bot for managing modmail written in Python." From 0cfde21def0d14795f44cd76196e01d6dfb103e6 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Sat, 26 Aug 2023 14:16:31 +0930 Subject: [PATCH 18/22] Switched classes to dataclasses. Refactored to use `async_db_cursor` decorator with type hints. Fixed formatting and linting issues. --- db.py | 304 ++++++++++++++++++++++++++++++---------------------------- 1 file changed, 155 insertions(+), 149 deletions(-) diff --git a/db.py b/db.py index ac72925..8cd7a7d 100644 --- a/db.py +++ b/db.py @@ -1,237 +1,243 @@ from contextlib import asynccontextmanager -import aiosqlite -from typing import Optional +from dataclasses import dataclass +import functools +from typing import Awaitable, Callable, Concatenate, Optional, ParamSpec, TypeVar +from aiosqlite import connect, Row, Cursor -path = '/database/modmail.db' + +PATH = "/database/modmail.db" + +P = ParamSpec("P") +R = TypeVar("R") @asynccontextmanager async def db_ops(): - conn = await aiosqlite.connect(path) - conn.row_factory = aiosqlite.Row + conn = await connect(PATH) + conn.row_factory = Row cursor = await conn.cursor() yield cursor await conn.commit() await conn.close() -class Ticket: - - def __init__(self, ticket_id: int, user: int, open: int, - message_id: Optional[int]): - self.ticket_id = ticket_id - self.user = user - self.open = open - self.message_id = message_id +def async_db_cursor( + func: Callable[Concatenate[Cursor, P], Awaitable[R]] +) -> Callable[P, Awaitable[R]]: + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + async with db_ops() as cursor: + return await func(cursor, *args, **kwargs) - def __repr__(self) -> str: - return f"Ticket({self.ticket_id}, {self.user}, {self.open}, {self.message_id})" + return wrapper -class TicketResponse: +@dataclass +class Ticket: + ticket_id: int + user: int + open: int + message_id: Optional[int] - def __init__(self, user: int, response: str, timestamp: int, - as_server: bool): - self.user = user - self.response = response - self.timestamp = timestamp - self.as_server = as_server - def __repr__(self) -> str: - return f"TicketResponse({self.user}, {self.response}, {self.timestamp}, {self.as_server})" +@dataclass +class TicketResponse: + user: int + response: str + timestamp: int + as_server: bool +@dataclass class Timeout: + timeout_id: int + timestamp: int - def __init__(self, timeout_id: int, timestamp: int): - self.timeout_id = timeout_id - self.timestamp = timestamp - def __repr__(self) -> str: - return f"Timeout({self.timeout_id}, {self.user}, {self.timestamp})" - - -async def get_ticket(ticket_id: int) -> Optional[Ticket]: +@async_db_cursor +async def get_ticket(cursor: Cursor, ticket_id: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE ticket_id=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [ticket_id]) - ticket = await cursor.fetchone() - if ticket is None or len(ticket) == 0: - return None - else: - return Ticket(*ticket) + await cursor.execute(sql, [ticket_id]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) -async def get_ticket_by_user(user: int) -> Optional[Ticket]: +@async_db_cursor +async def get_ticket_by_user(cursor: Cursor, user: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE user=? AND open=1 """ - async with db_ops() as cursor: - await cursor.execute(sql, [user]) - ticket = await cursor.fetchone() - if ticket is None or len(ticket) == 0: - return None - else: - return Ticket(*ticket) + await cursor.execute(sql, [user]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) -async def get_ticket_by_message(message_id: int) -> Optional[Ticket]: +@async_db_cursor +async def get_ticket_by_message(cursor: Cursor, message_id: int) -> Optional[Ticket]: sql = """ SELECT ticket_id, user, open, message_id FROM mm_tickets WHERE message_id=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [message_id]) - ticket = await cursor.fetchone() - if ticket is None or len(ticket) == 0: - return None - else: - return Ticket(*ticket) + await cursor.execute(sql, [message_id]) + ticket = await cursor.fetchone() + if ticket is None or len(ticket) == 0: + return None + else: + return Ticket(*ticket) -async def open_ticket(user: int) -> Optional[int]: +@async_db_cursor +async def open_ticket(cursor: Cursor, user: int) -> Optional[int]: sql = """ INSERT INTO mm_tickets (user) VALUES (?) """ - async with db_ops() as cursor: - await cursor.execute(sql, [user]) - return cursor.lastrowid + await cursor.execute(sql, [user]) + return cursor.lastrowid -async def update_ticket_message(ticket_id: int, message_id: int) -> bool: +@async_db_cursor +async def update_ticket_message( + cursor: Cursor, ticket_id: int, message_id: int +) -> bool: sql = """ UPDATE mm_tickets SET message_id=? WHERE ticket_id=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [message_id, ticket_id]) - return cursor.rowcount != 0 + await cursor.execute(sql, [message_id, ticket_id]) + return cursor.rowcount != 0 -async def close_ticket(ticket_id: int) -> bool: +@async_db_cursor +async def close_ticket(cursor: Cursor, ticket_id: int) -> bool: sql = """ UPDATE mm_tickets SET open=0 WHERE ticket_id=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [ticket_id]) - return cursor.rowcount != 0 + await cursor.execute(sql, [ticket_id]) + return cursor.rowcount != 0 -async def get_ticket_responses(ticket_id: int) -> list[TicketResponse]: +@async_db_cursor +async def get_ticket_responses(cursor: Cursor, ticket_id: int) -> list[TicketResponse]: sql = """ SELECT user, response, timestamp, as_server FROM mm_ticket_responses WHERE ticket_id=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [ticket_id]) - rows = await cursor.fetchall() - return [TicketResponse(*row) for row in rows] + await cursor.execute(sql, [ticket_id]) + rows = await cursor.fetchall() + return [TicketResponse(*row) for row in rows] -async def add_ticket_response(ticket_id: int, user: int, response: str, - as_server: bool) -> Optional[int]: +@async_db_cursor +async def add_ticket_response( + cursor: Cursor, ticket_id: int, user: int, response: str, as_server: bool +) -> Optional[int]: sql = """ INSERT INTO mm_ticket_responses (ticket_id, user, response, as_server) VALUES (?, ?, ?, ?) """ - async with db_ops() as cursor: - await cursor.execute(sql, [ticket_id, user, response, as_server]) - return cursor.lastrowid + await cursor.execute(sql, [ticket_id, user, response, as_server]) + return cursor.lastrowid -async def get_timeout(user: int) -> Optional[Timeout]: +@async_db_cursor +async def get_timeout(cursor: Cursor, user: int) -> Optional[Timeout]: sql = """ SELECT timeout_id, timestamp FROM mm_timeouts WHERE user=? """ - async with db_ops() as cursor: - await cursor.execute(sql, [user]) - timeout = await cursor.fetchone() - if timeout is None or len(timeout) == 0: - return None - else: - return Timeout(*timeout) + await cursor.execute(sql, [user]) + timeout = await cursor.fetchone() + if timeout is None or len(timeout) == 0: + return None + else: + return Timeout(*timeout) -async def set_timeout(user: int, timestamp: int) -> Optional[int]: +@async_db_cursor +async def set_timeout(cursor: Cursor, user: int, timestamp: int) -> Optional[int]: sql = """ INSERT OR REPLACE INTO mm_timeouts (user, timestamp) VALUES (?, ?) """ - async with db_ops() as cursor: - await cursor.execute(sql, [user, timestamp]) - return cursor.lastrowid - - -async def init(): - async with db_ops() as cursor: - # Create modmail tickets table - sql = """ - CREATE TABLE IF NOT EXISTS mm_tickets ( - ticket_id INTEGER PRIMARY KEY AUTOINCREMENT, - user INTEGER NOT NULL, - open BOOLEAN DEFAULT 1 NOT NULL, - message_id INTEGER - ); - """ - await cursor.execute(sql) - - # Create modmail ticket user index - sql = "CREATE INDEX IF NOT EXISTS mm_tickets_user ON mm_tickets(user);" - await cursor.execute(sql) - - # Create modmail ticket message index - sql = "CREATE INDEX IF NOT EXISTS mm_tickets_message ON mm_tickets(message_id);" - await cursor.execute(sql) - - # Create modmail ticket repsonses table - sql = """ - CREATE TABLE IF NOT EXISTS mm_ticket_responses ( - response_id INTEGER PRIMARY KEY AUTOINCREMENT, - ticket_id INTEGER, - user INTEGER NOT NULL, - response TEXT NOT NULL, - timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL, - as_server BOOLEAN NOT NULL, - FOREIGN KEY (ticket_id) REFERENCES mm_tickets (ticket_id) - ); - """ - await cursor.execute(sql) - - # Create modmail ticket response ticket id index - sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_ticket_id ON mm_ticket_responses(ticket_id);" - await cursor.execute(sql) - - # Create modmail ticket response user index - sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_user ON mm_ticket_responses(user);" - await cursor.execute(sql) - - # Create modmail timeouts table - sql = """ - CREATE TABLE IF NOT EXISTS mm_timeouts ( - timeout_id INTEGER PRIMARY KEY AUTOINCREMENT, - user INTEGER NOT NULL UNIQUE, - timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL - ); - """ - await cursor.execute(sql) - - # Create modmail timeout user index - sql = "CREATE UNIQUE INDEX IF NOT EXISTS mm_timeouts_user ON mm_timeouts(user);" - await cursor.execute(sql) - - return True + await cursor.execute(sql, [user, timestamp]) + return cursor.lastrowid + + +@async_db_cursor +async def init(cursor: Cursor): + # Create modmail tickets table + sql = """ + CREATE TABLE IF NOT EXISTS mm_tickets ( + ticket_id INTEGER PRIMARY KEY AUTOINCREMENT, + user INTEGER NOT NULL, + open BOOLEAN DEFAULT 1 NOT NULL, + message_id INTEGER + ); + """ + await cursor.execute(sql) + + # Create modmail ticket user index + sql = "CREATE INDEX IF NOT EXISTS mm_tickets_user ON mm_tickets(user);" + await cursor.execute(sql) + + # Create modmail ticket message index + sql = "CREATE INDEX IF NOT EXISTS mm_tickets_message ON mm_tickets(message_id);" + await cursor.execute(sql) + + # Create modmail ticket repsonses table + sql = """ + CREATE TABLE IF NOT EXISTS mm_ticket_responses ( + response_id INTEGER PRIMARY KEY AUTOINCREMENT, + ticket_id INTEGER, + user INTEGER NOT NULL, + response TEXT NOT NULL, + timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL, + as_server BOOLEAN NOT NULL, + FOREIGN KEY (ticket_id) REFERENCES mm_tickets (ticket_id) + ); + """ + await cursor.execute(sql) + + # Create modmail ticket response ticket id index + sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_ticket_id ON mm_ticket_responses(ticket_id);" + await cursor.execute(sql) + + # Create modmail ticket response user index + sql = "CREATE INDEX IF NOT EXISTS mm_ticket_responses_user ON mm_ticket_responses(user);" + await cursor.execute(sql) + + # Create modmail timeouts table + sql = """ + CREATE TABLE IF NOT EXISTS mm_timeouts ( + timeout_id INTEGER PRIMARY KEY AUTOINCREMENT, + user INTEGER NOT NULL UNIQUE, + timestamp TIMESTAMP DEFAULT (strftime('%s', 'now')) NOT NULL + ); + """ + await cursor.execute(sql) + + # Create modmail timeout user index + sql = "CREATE UNIQUE INDEX IF NOT EXISTS mm_timeouts_user ON mm_timeouts(user);" + await cursor.execute(sql) + + return True From a5cbea370869f4e5e9a218f7cc383d088908105b Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:26:44 +0930 Subject: [PATCH 19/22] Modified `.flake8` to work with Black formatting --- .flake8 | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 6deafc2..8dd399a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] -max-line-length = 120 +max-line-length = 88 +extend-ignore = E203 From cca8b1708f7c013884d011cfda602239fa058dbe Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:27:00 +0930 Subject: [PATCH 20/22] Added comments about working in dev/prod with the db --- db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/db.py b/db.py index 8cd7a7d..5694438 100644 --- a/db.py +++ b/db.py @@ -5,6 +5,8 @@ from aiosqlite import connect, Row, Cursor +# For development/local testing, use "modmail.db" +# For production and working with Docker, use "/database/modmail.db" PATH = "/database/modmail.db" P = ParamSpec("P") From 4ec74f606bed8bb0a5404b8eb1721fd40735b714 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:27:43 +0930 Subject: [PATCH 21/22] Added `ConfZ` to project and modified `config.py` and references to config --- cogs/commands.py | 7 +++++-- cogs/listeners.py | 7 +++++-- modmail.py | 15 ++++++++------ requirements.txt | 9 ++++++++ utils/config.py | 48 +++++++++++++++---------------------------- utils/ticket_embed.py | 15 ++++++++------ 6 files changed, 53 insertions(+), 48 deletions(-) diff --git a/cogs/commands.py b/cogs/commands.py index 006124a..beddae7 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -5,12 +5,15 @@ import discord import db -from utils import config, actions +from utils import actions +from utils.config import Config import logging logger = logging.getLogger(__name__) +modmail_config = Config() + class Commands(commands.Cog): """Cog to contain command action methods.""" @@ -165,7 +168,7 @@ async def setup(bot: commands.Bot): bot (commands.Bot): The bot. """ try: - modmail_channel = await bot.fetch_channel(config.channel) + modmail_channel = await bot.fetch_channel(modmail_config.channel) except Exception as e: raise ValueError( "The channel specified in config was not found. Please check your config." diff --git a/cogs/listeners.py b/cogs/listeners.py index ff3e450..c5521c9 100644 --- a/cogs/listeners.py +++ b/cogs/listeners.py @@ -4,12 +4,15 @@ import discord import db -from utils import config, uformatter, ticket_embed +from utils import uformatter, ticket_embed +from utils.config import Config import logging logger = logging.getLogger(__name__) +modmail_config = Config() + class Listeners(commands.Cog): """Cog to contain all main listener methods.""" @@ -120,7 +123,7 @@ async def setup(bot: commands.Bot): bot (commands.Bot): The bot. """ try: - modmail_channel = await bot.fetch_channel(config.channel) + modmail_channel = await bot.fetch_channel(modmail_config.channel) except Exception as e: raise ValueError( "The channel specified in config was not found. Please check your config." diff --git a/modmail.py b/modmail.py index fa7997b..4838ca0 100644 --- a/modmail.py +++ b/modmail.py @@ -2,7 +2,7 @@ from discord.ext import commands import db -from utils import config +from utils.config import Config from utils.ticket_embed import MessageButtonsView import logging @@ -21,6 +21,8 @@ member_cache = discord.MemberCacheFlags() +modmail_config = Config() + INITIAL_COGS = ["commands", "listeners"] @@ -28,9 +30,9 @@ class Modmail(commands.Bot): def __init__(self): super().__init__( intents=intents, - command_prefix=config.prefix, - description=config.status, - application_id=config.application_id, + command_prefix=modmail_config.prefix, + description=modmail_config.status, + application_id=modmail_config.application_id, member_cache_flags=member_cache, ) @@ -57,7 +59,8 @@ async def setup_hook(self): async def on_ready(self): await bot.change_presence( - activity=discord.Game(name=config.status), status=discord.Status.online + activity=discord.Game(name=modmail_config.status), + status=discord.Status.online, ) logger.info(f"Bot '{bot.user.name}' is now connected.") @@ -68,4 +71,4 @@ async def on_command_error(self, ctx: commands.Context, exception) -> None: bot = Modmail() -bot.run(config.token) +bot.run(modmail_config.token.get_secret_value()) diff --git a/requirements.txt b/requirements.txt index 196d252..18efe52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,19 @@ aiohttp==3.8.4 aiosignal==1.3.1 +aiosqlite==0.19.0 +annotated-types==0.5.0 async-timeout==4.0.2 attrs==23.1.0 charset-normalizer==3.1.0 +confz==2.0.1 discord.py==2.3.1 frozenlist==1.3.3 idna==3.4 multidict==6.0.4 +pydantic==2.3.0 +pydantic_core==2.6.3 +python-dotenv==1.0.0 +PyYAML==6.0.1 +toml==0.10.2 +typing_extensions==4.7.1 yarl==1.9.2 diff --git a/utils/config.py b/utils/config.py index 6d6adcf..6934d59 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,36 +1,20 @@ -import json -import os from pathlib import Path +from confz import BaseConfig, EnvSource, FileFormat, FileSource +from pydantic import SecretStr _path = Path(__file__).parent / "../config.json" -_config = json.load(open(_path, "r")) -token = ( - os.getenv("MODMAIL_TOKEN") if "MODMAIL_TOKEN" in os.environ else _config["token"] -) -application_id = ( - int(os.getenv("MODMAIL_APPLICATION_ID")) - if "MODMAIL_APPLICATION_ID" in os.environ - else _config["application_id"] -) -guild = ( - int(os.getenv("MODMAIL_GUILD")) - if "MODMAIL_GUILD" in os.environ - else _config["guild"] -) -channel = ( - int(os.getenv("MODMAIL_CHANNEL")) - if "MODMAIL_CHANNEL" in os.environ - else _config["channel"] -) -prefix = ( - os.getenv("MODMAIL_PREFIX") if "MODMAIL_PREFIX" in os.environ else _config["prefix"] -) -status = ( - os.getenv("MODMAIL_STATUS") if "MODMAIL_STATUS" in os.environ else _config["status"] -) -id_prefix = ( - os.getenv("MODMAIL_ID_PREFIX") - if "MODMAIL_ID_PREFIX" in os.environ - else _config["id_prefix"] -) + +class Config(BaseConfig): + token: SecretStr + application_id: int + guild: int + channel: int + prefix: str + status: str + id_prefix: str + + CONFIG_SOURCES = [ + FileSource(_path, format=FileFormat.JSON, optional=True), + EnvSource(prefix="MODMAIL_", allow_all=True), + ] diff --git a/utils/ticket_embed.py b/utils/ticket_embed.py index 476493b..53265a1 100644 --- a/utils/ticket_embed.py +++ b/utils/ticket_embed.py @@ -6,13 +6,16 @@ from discord.utils import format_dt import db -from utils import actions, config +from utils import actions +from utils.config import Config from utils.pagination import paginated_embed_menus import logging logger = logging.getLogger(__name__) +modmail_config = Config() + class ConfirmationView(discord.ui.View): """Confirmation view for yes/no operations.""" @@ -77,7 +80,7 @@ def __init__(self, bot: commands.Bot, embeds: Collection[discord.Embed]): self.embeds = embeds self.current_page = len(self.embeds) - 1 - @discord.ui.button(emoji="💬", custom_id=f"{config.id_prefix}:reply") + @discord.ui.button(emoji="💬", custom_id=f"{modmail_config.id_prefix}:reply") async def mail_reply(self, interaction: discord.Interaction, _): """ Replies to the ticket. @@ -85,7 +88,7 @@ async def mail_reply(self, interaction: discord.Interaction, _): ticket = await db.get_ticket_by_message(interaction.message.id) await actions.message_reply(self.bot, interaction, ticket) - @discord.ui.button(emoji="❎", custom_id=f"{config.id_prefix}:close") + @discord.ui.button(emoji="❎", custom_id=f"{modmail_config.id_prefix}:close") async def mail_close(self, interaction: discord.Interaction, _): """ Closes the ticket. @@ -96,7 +99,7 @@ async def mail_close(self, interaction: discord.Interaction, _): ) or await interaction.guild.fetch_member(ticket.user) await actions.message_close(interaction, ticket, member) - @discord.ui.button(emoji="⏲️", custom_id=f"{config.id_prefix}:timeout") + @discord.ui.button(emoji="⏲️", custom_id=f"{modmail_config.id_prefix}:timeout") async def mail_timeout(self, interaction: discord.Interaction, _): """ Times out the user of the ticket. @@ -110,7 +113,7 @@ async def mail_timeout(self, interaction: discord.Interaction, _): @discord.ui.button( emoji="⬅️", style=discord.ButtonStyle.blurple, - custom_id=f"{config.id_prefix}:previous_page", + custom_id=f"{modmail_config.id_prefix}:previous_page", ) async def previous_page(self, interaction: discord.Interaction, _): """ @@ -131,7 +134,7 @@ async def previous_page(self, interaction: discord.Interaction, _): @discord.ui.button( emoji="➡️", style=discord.ButtonStyle.blurple, - custom_id=f"{config.id_prefix}:next_page", + custom_id=f"{modmail_config.id_prefix}:next_page", ) async def next_page(self, interaction: discord.Interaction, _): """ From b716589139dca34be344874009e77f03ba1646b7 Mon Sep 17 00:00:00 2001 From: NathanealV <18462634+NathanealV@users.noreply.github.com> Date: Fri, 8 Sep 2023 23:04:54 +0930 Subject: [PATCH 22/22] Changed `docker-compose.yml` image to reflect repository name --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 8ae5b1e..018a9bf 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: "3.8" services: modmail: - image: ghcr.io/ib-ai/modmail-py:latest + image: ghcr.io/ib-ai/modmail.py:latest env_file: modmail.env restart: on-failure volumes: