diff --git a/app/features/entity_mentions.py b/app/features/entity_mentions.py index e0ed134..f3b4b01 100644 --- a/app/features/entity_mentions.py +++ b/app/features/entity_mentions.py @@ -1,7 +1,9 @@ import asyncio +import datetime as dt import re from contextlib import suppress from types import SimpleNamespace +from typing import Literal, Protocol, cast import discord import github @@ -17,6 +19,10 @@ IGNORED_MESSAGE_TYPES = frozenset( (discord.MessageType.thread_created, discord.MessageType.channel_name_change) ) +REPOSITORIES: dict[str, Repository] = { + kind: gh.get_repo(f"{config.GITHUB_ORG}/{name}", lazy=True) + for kind, name in config.GITHUB_REPOS.items() +} DISCUSSION_QUERY = """ query getDiscussion($number: Int!, $org: String!, $repo: String!) { @@ -30,7 +36,50 @@ } """ -repo_cache: dict[str, Repository] = {} +RepoName = Literal["web", "bot", "main"] +CacheKey = tuple[RepoName, int] +EntityKind = Literal["Pull Request", "Issue", "Discussion"] + + +class Entity(Protocol): + number: int + title: str + html_url: str + + +class TTLCache: + def __init__(self, ttl: int) -> None: + self._ttl = dt.timedelta(seconds=ttl) + self._cache: dict[CacheKey, tuple[dt.datetime, EntityKind, Entity]] = {} + + def _fetch_entity(self, key: CacheKey) -> None: + repo_name, entity_id = key + try: + entity = REPOSITORIES[repo_name].get_issue(entity_id) + kind = "Pull Request" if entity.pull_request else "Issue" + except github.UnknownObjectException: + try: + entity = get_discussion(REPOSITORIES[repo_name], entity_id) + kind = "Discussion" + except github.GithubException: + raise KeyError(key) from None + self._cache[key] = (dt.datetime.now(), kind, cast(Entity, entity)) + + def _refresh(self, key: CacheKey) -> None: + if key not in self._cache: + self._fetch_entity(key) + return + timestamp, *_ = self._cache[key] + if dt.datetime.now() - timestamp >= self._ttl: + self._fetch_entity(key) + + def __getitem__(self, key: CacheKey) -> tuple[EntityKind, Entity]: + self._refresh(key) + _, kind, entity = self._cache[key] + return kind, entity + + +entity_cache = TTLCache(1800) # 30 minutes async def handle_entities(message: Message) -> None: @@ -46,22 +95,9 @@ async def handle_entities(message: Message) -> None: entities: list[str] = [] for match in ENTITY_REGEX.finditer(message.content): - repo_name = config.GITHUB_REPOS[match[1] or "main"] - if (repo := repo_cache.get(repo_name)) is None: - repo_cache[repo_name] = repo = gh.get_repo( - f"{config.GITHUB_ORG}/{repo_name}", lazy=True - ) - entity_id = int(match[2]) - try: - entity = repo.get_issue(entity_id) - kind = "Pull Request" if entity.pull_request else "Issue" - except github.UnknownObjectException: - try: - entity = get_discussion(repo, entity_id) - kind = "Discussion" - except github.GithubException: - continue - if entity_id < 10: + repo_name = cast(RepoName, match[1] or "main") + kind, entity = entity_cache[repo_name, int(match[2])] + if entity.number < 10: # Ignore single-digit mentions (likely a false positive) continue entities.append(ENTITY_TEMPLATE.format(kind=kind, entity=entity))