Skip to content

Commit

Permalink
Add cache to get_tags_for_room(...)
Browse files Browse the repository at this point in the history
  • Loading branch information
MadLittleMods committed Sep 18, 2024
1 parent 61b7c31 commit 4ff42d3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def _invalidate_caches_for_room(self, room_id: str) -> None:

self._attempt_to_invalidate_cache("get_account_data_for_room", None)
self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None)
self._attempt_to_invalidate_cache("get_tags_for_room", None)
self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,))
self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
Expand Down
13 changes: 13 additions & 0 deletions synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]:

return results

@cached(num_args=2, tree=True)
async def get_tags_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
Expand Down Expand Up @@ -213,6 +214,7 @@ def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))
self.get_tags_for_room.invalidate((user_id, room_id))

return self._account_data_id_gen.get_current_token()

Expand All @@ -237,6 +239,7 @@ def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))
self.get_tags_for_room.invalidate((user_id, room_id))

return self._account_data_id_gen.get_current_token()

Expand Down Expand Up @@ -290,9 +293,19 @@ def process_replication_rows(
rows: Iterable[Any],
) -> None:
if stream_name == AccountDataStream.NAME:
# Cast is safe because the `AccountDataStream` should only be giving us
# `AccountDataStreamRow`
rows: List[AccountDataStream.AccountDataStreamRow] = cast(
List[AccountDataStream.AccountDataStreamRow], rows
)

for row in rows:
if row.data_type == AccountDataTypes.TAG:
self.get_tags_for_user.invalidate((row.user_id,))
if row.room_id:
self.get_tags_for_room.invalidate((row.user_id, row.room_id))
else:
self.get_tags_for_room.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
row.user_id, token
)
Expand Down

0 comments on commit 4ff42d3

Please sign in to comment.