Skip to content

Commit

Permalink
Reduce amount of state pulled out when querying federation hierachy (#…
Browse files Browse the repository at this point in the history
…16785)

There are two changes here:

1. Only pull out the required state when handling the request.
2. Change the get filtered state return type to check that we're only
querying state that was requested

---------

Co-authored-by: reivilibre <[email protected]>
  • Loading branch information
erikjohnston and reivilibre committed Jan 10, 2024
1 parent 4c67f03 commit 578c5c7
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog.d/16785.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce amount of state pulled out when querying federation hierachy.
12 changes: 11 additions & 1 deletion synapse/handlers/room_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection
from synapse.types.state import StateFilter
from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
Expand Down Expand Up @@ -546,7 +547,16 @@ async def _is_local_room_accessible(
Returns:
True if the room is accessible to the requesting user or server.
"""
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
event_types = [
(EventTypes.JoinRules, ""),
(EventTypes.RoomHistoryVisibility, ""),
]
if requester:
event_types.append((EventTypes.Member, requester))

state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, state_filter=StateFilter.from_types(event_types)
)

# If there's no state for the room, it isn't known.
if not state_ids:
Expand Down
48 changes: 46 additions & 2 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)

import attr
Expand All @@ -52,7 +55,7 @@
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.types import JsonDict, JsonMapping, StateKey, StateMap
from synapse.types.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -64,6 +67,8 @@

logger = logging.getLogger(__name__)

_T = TypeVar("_T")


MAX_STATE_DELTA_HOPS = 100

Expand Down Expand Up @@ -349,7 +354,8 @@ async def get_partial_filtered_current_state_ids(
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
results = StateMapWrapper(state_filter=state_filter or StateFilter.all())

sql = """
SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
Expand Down Expand Up @@ -726,3 +732,41 @@ def __init__(
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)


@attr.s(auto_attribs=True, slots=True)
class StateMapWrapper(Dict[StateKey, str]):
"""A wrapper around a StateMap[str] to ensure that we only query for items
that were not filtered out.
This is to help prevent bugs where we filter out state but other bits of the
code expect the state to be there.
"""

state_filter: StateFilter

def __getitem__(self, key: StateKey) -> str:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().__getitem__(key)

@overload
def get(self, key: Tuple[str, str]) -> Optional[str]:
...

@overload
def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]:
...

def get(
self, key: StateKey, default: Union[str, _T, None] = None
) -> Union[str, _T, None]:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)
return super().get(key, default)

def __contains__(self, key: Any) -> bool:
if key not in self.state_filter:
raise Exception("State map was filtered and doesn't include: %s", key)

return super().__contains__(key)
24 changes: 24 additions & 0 deletions synapse/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Expand Down Expand Up @@ -584,6 +585,29 @@ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
# local users only
return False

def __contains__(self, key: Any) -> bool:
if not isinstance(key, tuple) or len(key) != 2:
raise TypeError(
f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}"
)

typ, state_key = key

if not isinstance(typ, str) or not isinstance(state_key, str):
raise TypeError(
f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})"
)

if typ in self.types:
state_keys = self.types[typ]
if state_keys is None or state_key in state_keys:
return True

elif self.include_others:
return True

return False


_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
Expand Down

0 comments on commit 578c5c7

Please sign in to comment.