Skip to content

Commit

Permalink
Lookup project abbreviations in Discord home server if set.
Browse files Browse the repository at this point in the history
  • Loading branch information
synrg committed Jul 31, 2024
1 parent 0f26988 commit 1524f25
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
29 changes: 17 additions & 12 deletions inatcog/commands/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..embeds.inat import INatEmbeds
from ..interfaces import MixinMeta
from ..places import RESERVED_PLACES
from ..utils import has_valid_user_config
from ..utils import get_home_server, has_valid_user_config

logger = logging.getLogger("red.dronefly." + __name__)

Expand All @@ -35,7 +35,7 @@ async def project(self, ctx, *, query):
- *abbreviation* defined with `[p]project add`; see `[p]help project add` for details.
"""
try:
project = await self.project_table.get_project(ctx.guild, query)
project = await self.project_table.get_project(ctx.guild, query, ctx.author)
embed = make_embed(
title=project.title,
url=project.url,
Expand All @@ -48,8 +48,9 @@ async def project(self, ctx, *, query):
if project.icon:
embed.set_thumbnail(url=project.icon)
embed.add_field(name="Project number", value=project.id)
if ctx.guild:
guild_config = self.config.guild(ctx.guild)
guild = ctx.guild or await get_home_server(self, ctx.author)
if guild:
guild_config = self.config.guild(guild)
projects = await guild_config.projects()
proj_abbrevs = [
abbrev for abbrev in projects if projects[abbrev] == project.id
Expand Down Expand Up @@ -102,10 +103,10 @@ async def project_add(self, ctx, abbrev: str, project_number: int):
@checks.bot_has_permissions(embed_links=True, read_message_history=True)
async def project_list(self, ctx, *, match=""):
"""List projects with abbreviations on this server."""
if not ctx.guild:
guild = ctx.guild or await get_home_server(self, ctx.author)
if not guild:
return

config = self.config.guild(ctx.guild)
config = self.config.guild(guild)
projects = await config.projects()
result_pages = []

Expand Down Expand Up @@ -138,7 +139,7 @@ async def project_list(self, ctx, *, match=""):
"%s (places: %s, guild: %d)",
err,
",".join(proj_id_group),
ctx.guild.id,
guild.id,
)

# Iterate over projects and do a quick cache lookup per project:
Expand All @@ -147,7 +148,9 @@ async def project_list(self, ctx, *, match=""):
proj_str_text = ""
if proj_id in self.api.projects_cache:
try:
project = await self.project_table.get_project(ctx.guild, proj_id)
project = await self.project_table.get_project(
guild, proj_id, ctx.author
)
proj_str = f"{abbrev}: [{project.title}]({project.url})"
proj_str_text = f"{abbrev} {project.title}"
except LookupError as err:
Expand All @@ -161,7 +164,7 @@ async def project_list(self, ctx, *, match=""):
"Project in cache could not be retrieved: %s (project: %d, guild: %d)",
err,
proj_id,
ctx.guild.id,
guild.id,
)
# In the unlikely case of the deletion of a project that is cached:
proj_str = f"{abbrev}: {proj_id} not found."
Expand All @@ -174,7 +177,7 @@ async def project_list(self, ctx, *, match=""):
"Project deleted? %s: %d (guild: %d)",
abbrev,
proj_id,
ctx.guild.id,
guild.id,
)
proj_str = f"{abbrev}: [{proj_id}]({WWW_BASE_URL}/projects/{proj_id})"
proj_str_text = abbrev
Expand Down Expand Up @@ -239,7 +242,9 @@ async def project_stats(self, ctx, project: str, *, user: str = "me"):
error_msg = None
async with ctx.typing():
try:
proj = await self.project_table.get_project(ctx.guild, project)
proj = await self.project_table.get_project(
ctx.guild, project, ctx.author
)
ctx_member = await MemberConverter.convert(ctx, user)
member = ctx_member.member
user = await self.user_table.get_user(member)
Expand Down
15 changes: 12 additions & 3 deletions inatcog/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from pyinaturalist.models import Project

from .converters.base import QuotedContextMemberConverter
from .utils import get_home_server


class UserProject(Project):
"""A collection project for observations by specific users.
Expand Down Expand Up @@ -100,19 +103,25 @@ class INatProjectTable:
def __init__(self, cog):
self.cog = cog

async def get_project(self, guild, query: Union[int, str]):
async def get_project(
self,
guild,
query: Union[int, str],
user: QuotedContextMemberConverter = None,
):
"""Get project by guild abbr or via id#/keyword lookup in API."""
project = None
response = None
abbrev = None

_guild = guild or await get_home_server(self.cog, user)
if isinstance(query, str):
abbrev = query.lower()
if isinstance(query, int) or query.isnumeric():
project_id = query
response = await self.cog.api.get_projects(int(project_id))
if guild and abbrev:
guild_config = self.cog.config.guild(guild)
if _guild and abbrev:
guild_config = self.cog.config.guild(_guild)
projects = await guild_config.projects()
if abbrev in projects:
response = await self.cog.api.get_projects(projects[abbrev])
Expand Down
4 changes: 3 additions & 1 deletion inatcog/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ async def get(
args = get_base_query_args(query)

args["project"] = (
await self.cog.project_table.get_project(ctx.guild, query.project)
await self.cog.project_table.get_project(
ctx.guild, query.project, ctx.author
)
if has_value(query.project)
else None
)
Expand Down

0 comments on commit 1524f25

Please sign in to comment.