Skip to content

Commit

Permalink
add doc & small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
AiroPi committed Mar 24, 2024
1 parent 949ad7b commit be15d5b
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 93 deletions.
6 changes: 3 additions & 3 deletions src/cogs/poll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def edit_poll(self, inter: Interaction, message: discord.Message) -> None:
if poll.author_id != inter.user.id:
raise NonSpecificError(_("You are not the author of this poll. You can't edit it.", _l=256))
await inter.response.send_message(
**(await PollDisplay.build(poll, self.bot)),
**(await PollDisplay(poll, self.bot)),
view=await EditPoll(self, poll, message),
ephemeral=True,
)
Expand Down Expand Up @@ -133,7 +133,7 @@ async def on_submit(self, inter: discord.Interaction):
self.poll.title = self.question.value
self.poll.description = self.description.value
await inter.response.send_message(
**(await PollDisplay.build(self.poll, self.bot)),
**(await PollDisplay(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message),
ephemeral=True,
)
Expand Down Expand Up @@ -173,7 +173,7 @@ async def on_submit(self, inter: discord.Interaction):
self.poll.choices.append(db.PollChoice(poll_id=self.poll.id, label=self.choice3.value))

await inter.response.send_message(
**(await PollDisplay.build(self.poll, self.bot)),
**(await PollDisplay(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message),
ephemeral=True,
)
Expand Down
72 changes: 35 additions & 37 deletions src/cogs/poll/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from core import Emojis, db
from core.i18n import _
from core.response import MessageDisplay
from core.utils import AsyncInitMixin

from .constants import (
BOOLEAN_INDEXES,
Expand All @@ -28,56 +29,53 @@
from mybot import MyBot


class PollDisplay:
def __init__(self, poll: Poll, votes: dict[str, int] | None):
self.poll: Poll = poll
self.votes = votes
class PollDisplay(AsyncInitMixin, MessageDisplay):
async def __init__(self, poll: Poll, bot: MyBot, old_embed: Embed | None = None):
self.poll = poll
self.votes: dict[str, int] | None = await self.get_votes(bot)

@classmethod
async def build(cls, poll: Poll, bot: MyBot, old_embed: Embed | None = None) -> MessageDisplay:
content = poll.description
embed = discord.Embed(title=poll.title)

votes: dict[str, int] | None
if poll.public_results is True:
async with bot.async_session.begin() as session:
stmt = (
db.select(db.PollAnswer.value, func.count())
.select_from(db.PollAnswer)
.where(db.PollAnswer.poll_id == poll.id)
.group_by(db.PollAnswer.value)
)

votes = { # noqa: C416, dict comprehension used for typing purposes
key: value
for key, value in (await session.execute(stmt)).all() # choice_id: vote_count
}
if poll.type == db.PollType.CHOICE:
# when we delete a choice from a poll, the votes are still in the db before commit
# so we need to filter them
votes = {
key: value for key, value in votes.items() if key in (str(choice.id) for choice in poll.choices)
}
else:
votes = None

poll_display = cls(poll, votes)

description_split: list[str] = [poll_display.build_end_date(), poll_display.build_legend()]

description_split: list[str] = [self.build_end_date(), self.build_legend()]
embed.description = "\n".join(description_split)

if poll.public_results:
embed.add_field(name="\u200b", value=poll_display.build_graph())
embed.color = poll_display.build_color()
embed.add_field(name="\u200b", value=self.build_graph())
embed.color = self.build_color()

if old_embed:
embed.set_footer(text=old_embed.footer.text)
else:
author = await bot.getch_user(poll.author_id)
embed.set_footer(text=_("Poll created by {}", author.name if author else "unknown"))

return MessageDisplay(content=content, embed=embed)
MessageDisplay.__init__(self, content=content, embed=embed)

async def get_votes(self, bot: MyBot) -> dict[str, int] | None:
if self.poll.public_results is False:
return None

async with bot.async_session.begin() as session:
stmt = (
db.select(db.PollAnswer.value, func.count())
.select_from(db.PollAnswer)
.where(db.PollAnswer.poll_id == self.poll.id)
.group_by(db.PollAnswer.value)
)

votes = { # noqa: C416, dict comprehension used for typing purposes
key: value
for key, value in (await session.execute(stmt)).all() # choice_id: vote_count
}
if self.poll.type == db.PollType.CHOICE:
# when we delete a choice from a poll, the votes are still in the db before commit
# so we need to filter them
votes = {
key: value
for key, value in votes.items()
if key in (str(choice.id) for choice in self.poll.choices)
}
return votes

@property
def total_votes(self) -> int:
Expand Down
40 changes: 18 additions & 22 deletions src/cogs/poll/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def update(self) -> None:
self.toggle_poll.style = discord.ButtonStyle.red

async def message_display(self) -> MessageDisplay:
return await PollDisplay.build(self.poll, self.bot)
return await PollDisplay(self.poll, self.bot)

@ui.button(row=4, style=discord.ButtonStyle.red)
async def reset_votes(self, inter: discord.Interaction, button: ui.Button[Self]):
Expand All @@ -63,7 +63,7 @@ async def reset_votes(self, inter: discord.Interaction, button: ui.Button[Self])
async def toggle_poll(self, inter: discord.Interaction, button: ui.Button[Self]):
del button # unused
self.poll.closed = not self.poll.closed
await self.message_refresh(inter)
await self.edit_message(inter)

@ui.button(row=4, style=discord.ButtonStyle.green)
async def save(self, inter: discord.Interaction, button: ui.Button[Self]):
Expand All @@ -74,7 +74,7 @@ async def save(self, inter: discord.Interaction, button: ui.Button[Self]):
# channel can be other type of channels like voice, but it's ok.
channel = cast(discord.TextChannel, inter.channel)
message = await channel.send(
**(await PollDisplay.build(self.poll, self.bot)), view=await PollPublicMenu(self.cog, self.poll)
**(await PollDisplay(self.poll, self.bot)), view=await PollPublicMenu(self.cog, self.poll)
)
self.poll.message_id = message.id

Expand All @@ -91,7 +91,7 @@ async def save(self, inter: discord.Interaction, button: ui.Button[Self]):
await inter.delete_original_response()

await self.poll_message.edit(
**(await PollDisplay.build(self.poll, self.bot)), view=await PollPublicMenu(self.cog, self.poll)
**(await PollDisplay(self.poll, self.bot)), view=await PollPublicMenu(self.cog, self.poll)
)

currents = self.cog.current_votes.pop(self.poll.id, None)
Expand Down Expand Up @@ -137,12 +137,12 @@ def set_options(self, poll: db.Poll):
# )

async def callback(self, inter: Interaction[MyBot]) -> None: # pyright: ignore [reportIncompatibleMethodOverride]
view = cast(EditPoll, self.view)
menu = cast(EditPoll, self.view)
if self.values[0] == "public_results":
view.poll.public_results = not view.poll.public_results
menu.poll.public_results = not menu.poll.public_results
elif self.values[0] == "users_can_change_answer":
view.poll.users_can_change_answer = not view.poll.users_can_change_answer
await view.message_refresh(inter)
menu.poll.users_can_change_answer = not menu.poll.users_can_change_answer
await menu.edit_message(inter)


class EditPollMenus(ui.Select[EditPoll]):
Expand Down Expand Up @@ -180,12 +180,8 @@ async def __init__(self, parent: EditPoll):
await super().__init__(parent)
self.poll = parent.poll

async def set_back(self, inter: discord.Interaction):
await self.parent.update()
await super().set_back(inter)

async def update_poll_display(self, inter: discord.Interaction, view: ui.View | None = None):
await inter.response.edit_message(**(await PollDisplay.build(self.poll, self.bot)), view=view or self)
await inter.response.edit_message(**(await PollDisplay(self.poll, self.bot)), view=view or self)


class EditTitleAndDescription(EditSubmenu, ui.Modal):
Expand Down Expand Up @@ -219,7 +215,7 @@ async def __init__(self, parent: EditPoll):
async def on_submit(self, inter: discord.Interaction) -> None:
self.poll.title = self.question.value
self.poll.description = self.description.value
await self.set_back(inter)
await super().on_submit(inter)


class EditEndingTime(EditSubmenu):
Expand Down Expand Up @@ -269,7 +265,7 @@ async def __init__(self, parent: EditPoll):
option = next(opt for opt in self.select_minutes.options if opt.value == default_minutes)
option.default = True

async def cancel(self):
async def on_cancel(self):
self.poll.end_date = self.old_value

def set_time(self):
Expand Down Expand Up @@ -319,7 +315,7 @@ async def __init__(self, parent: EditPoll):
self.add_choice.label = _("Add a choice")
self.remove_choice.label = _("Remove a choice")

async def cancel(self):
async def on_cancel(self):
self.poll.choices = self.old_value

async def update(self):
Expand Down Expand Up @@ -378,7 +374,7 @@ async def choices_to_remove(self, inter: Interaction, select: ui.Select[Self]):
await self.update()
await self.parent.update_poll_display(inter, view=self)

async def cancel(self):
async def on_cancel(self):
self.parent.poll.choices = self.old_value


Expand All @@ -401,9 +397,9 @@ async def __init__(self, parent: EditPoll) -> None:
@ui.select(cls=ui.Select[Self])
async def max_choices(self, inter: Interaction, select: ui.Select[Self]):
self.parent.poll.max_answers = int(select.values[0])
await self.message_refresh(inter)
await self.edit_message(inter)

async def cancel(self):
async def on_cancel(self):
self.parent.poll.max_answers = self.old_value


Expand All @@ -422,9 +418,9 @@ async def __init__(self, parent: EditPoll) -> None:
@ui.select(cls=ui.RoleSelect[Self], max_values=25)
async def allowed_roles(self, inter: Interaction, select: ui.RoleSelect[Self]):
self.parent.poll.allowed_roles = [role.id for role in select.values]
await self.message_refresh(inter)
await self.edit_message(inter)

async def cancel(self):
async def on_cancel(self):
self.parent.poll.allowed_roles = self.old_value


Expand All @@ -445,4 +441,4 @@ async def reset(self, inter: discord.Interaction, button: ui.Button[Self]):
del button # unused
async with self.bot.async_session.begin() as session:
await session.execute(delete(db.PollAnswer).where(db.PollAnswer.poll_id == self.parent.poll.id))
await self.set_back(inter)
await self.set_menu(inter, self.parent)
8 changes: 4 additions & 4 deletions src/cogs/poll/vote_menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from core import Menu, ResponseType, db, response_constructor
from core.constants import Emojis
from core.i18n import _
from core.view_menus import SubMenu
from core.view_menus import SubMenuWithoutButtons

from .constants import LEGEND_EMOJIS
from .display import PollDisplay
Expand Down Expand Up @@ -112,7 +112,7 @@ async def vote(self, inter: discord.Interaction, button: ui.Button[Self]):
)


class VoteMenu(SubMenu[PollPublicMenu]):
class VoteMenu(SubMenuWithoutButtons[PollPublicMenu]):
async def __init__(
self, parent: PollPublicMenu, poll: db.Poll, user_votes: Sequence[db.PollAnswer], base_inter: Interaction
):
Expand All @@ -132,7 +132,7 @@ async def update_poll_display(self):
try:
message = cast(discord.Message, self.base_inter.message) # type: ignore
old_embed = message.embeds[0] if message.embeds else None
await message.edit(**(await PollDisplay.build(self.poll, self.parent.bot, old_embed)))
await message.edit(**await PollDisplay(self.poll, self.parent.bot, old_embed))
except discord.NotFound:
pass

Expand Down Expand Up @@ -174,7 +174,7 @@ async def update(self):
@ui.select(cls=ui.Select[Self])
async def choice(self, inter: Interaction, select: ui.Select[Self]):
del select # unused
await self.message_refresh(inter, False)
await self.edit_message(inter, False)

@ui.button(style=discord.ButtonStyle.red)
async def remove_vote(self, inter: Interaction, button: ui.Button[Self]):
Expand Down
4 changes: 2 additions & 2 deletions src/core/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def __getitem__(self, key: str) -> Any:
return self.__dict__[key]

def __iter__(self) -> Iterator[str]:
return iter(self.__dict__)
return iter(self.__dataclass_fields__)

def __len__(self) -> int:
return self.__dict__.__len__()
return self.__dataclass_fields__.__len__()


class UneditedMessageDisplay(Mapping[str, Any]):
Expand Down
Loading

0 comments on commit be15d5b

Please sign in to comment.