Skip to content

Commit

Permalink
Merge pull request #79 from trag1c/entity-cache
Browse files Browse the repository at this point in the history
feat: implement a TTL cache for entities
  • Loading branch information
mitchellh authored Dec 28, 2024
2 parents 2e8fb76 + ef66410 commit 4a01584
Showing 1 changed file with 53 additions and 17 deletions.
70 changes: 53 additions & 17 deletions app/features/entity_mentions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import datetime as dt
import re
from contextlib import suppress
from types import SimpleNamespace
from typing import Literal, Protocol, cast

import discord
import github
Expand All @@ -17,6 +19,10 @@
IGNORED_MESSAGE_TYPES = frozenset(
(discord.MessageType.thread_created, discord.MessageType.channel_name_change)
)
REPOSITORIES: dict[str, Repository] = {
kind: gh.get_repo(f"{config.GITHUB_ORG}/{name}", lazy=True)
for kind, name in config.GITHUB_REPOS.items()
}

DISCUSSION_QUERY = """
query getDiscussion($number: Int!, $org: String!, $repo: String!) {
Expand All @@ -30,7 +36,50 @@
}
"""

repo_cache: dict[str, Repository] = {}
RepoName = Literal["web", "bot", "main"]
CacheKey = tuple[RepoName, int]
EntityKind = Literal["Pull Request", "Issue", "Discussion"]


class Entity(Protocol):
number: int
title: str
html_url: str


class TTLCache:
def __init__(self, ttl: int) -> None:
self._ttl = dt.timedelta(seconds=ttl)
self._cache: dict[CacheKey, tuple[dt.datetime, EntityKind, Entity]] = {}

def _fetch_entity(self, key: CacheKey) -> None:
repo_name, entity_id = key
try:
entity = REPOSITORIES[repo_name].get_issue(entity_id)
kind = "Pull Request" if entity.pull_request else "Issue"
except github.UnknownObjectException:
try:
entity = get_discussion(REPOSITORIES[repo_name], entity_id)
kind = "Discussion"
except github.GithubException:
raise KeyError(key) from None
self._cache[key] = (dt.datetime.now(), kind, cast(Entity, entity))

def _refresh(self, key: CacheKey) -> None:
if key not in self._cache:
self._fetch_entity(key)
return
timestamp, *_ = self._cache[key]
if dt.datetime.now() - timestamp >= self._ttl:
self._fetch_entity(key)

def __getitem__(self, key: CacheKey) -> tuple[EntityKind, Entity]:
self._refresh(key)
_, kind, entity = self._cache[key]
return kind, entity


entity_cache = TTLCache(1800) # 30 minutes


async def handle_entities(message: Message) -> None:
Expand All @@ -46,22 +95,9 @@ async def handle_entities(message: Message) -> None:

entities: list[str] = []
for match in ENTITY_REGEX.finditer(message.content):
repo_name = config.GITHUB_REPOS[match[1] or "main"]
if (repo := repo_cache.get(repo_name)) is None:
repo_cache[repo_name] = repo = gh.get_repo(
f"{config.GITHUB_ORG}/{repo_name}", lazy=True
)
entity_id = int(match[2])
try:
entity = repo.get_issue(entity_id)
kind = "Pull Request" if entity.pull_request else "Issue"
except github.UnknownObjectException:
try:
entity = get_discussion(repo, entity_id)
kind = "Discussion"
except github.GithubException:
continue
if entity_id < 10:
repo_name = cast(RepoName, match[1] or "main")
kind, entity = entity_cache[repo_name, int(match[2])]
if entity.number < 10:
# Ignore single-digit mentions (likely a false positive)
continue
entities.append(ENTITY_TEMPLATE.format(kind=kind, entity=entity))
Expand Down

0 comments on commit 4a01584

Please sign in to comment.