Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add locking to more safely delete state groups #18107

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/18107.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix rare edge case where state groups could be deleted while we are persisting new events that reference them.
18 changes: 14 additions & 4 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

Expand Down Expand Up @@ -580,7 +582,9 @@ async def process_remote_join(
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down Expand Up @@ -1179,7 +1183,9 @@ async def _compute_event_context_with_maybe_missing_prevs(
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)

except Exception as e:
Expand Down Expand Up @@ -1874,7 +1880,9 @@ async def _check_event_auth(
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down Expand Up @@ -2014,7 +2022,9 @@ async def _check_for_soft_fail(
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
Expand Down
61 changes: 55 additions & 6 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
from synapse.util.stringutils import shortstr

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore

logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
Expand Down Expand Up @@ -194,6 +196,8 @@ def __init__(self, hs: "HomeServer"):
self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._state_store = hs.get_datastores().state
self._state_epoch_store = hs.get_datastores().state_deletion

self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
Expand Down Expand Up @@ -475,7 +479,10 @@ async def compute_event_context(
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
self,
room_id: str,
event_ids: StrCollection,
await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Expand Down Expand Up @@ -511,6 +518,19 @@ async def resolve_state_groups_for_events(
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)

if prev_group:
# Ensure that we still have the prev group, and ensure we don't
# delete it while we're persisting the event.
missing_state_group = (
await self._state_epoch_store.check_state_groups_and_bump_deletion(
{prev_group}
)
)
if missing_state_group:
prev_group = None
delta_ids = None

return _StateCacheEntry(
state=None,
state_group=state_group_id,
Expand All @@ -531,7 +551,7 @@ async def resolve_state_groups_for_events(
room_version,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
state_res_store=StateResolutionStore(self.store, self._state_epoch_store),
)
return result

Expand Down Expand Up @@ -663,14 +683,42 @@ async def resolve_state_groups(
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
return cache
# Check that the returned cache entry doesn't point to deleted
# state groups.
state_groups_to_check = set()
if cache.state_group is not None:
state_groups_to_check.add(cache.state_group)

if cache.prev_group is not None:
state_groups_to_check.add(cache.prev_group)

missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
state_groups_to_check
)

if not missing_state_groups:
return cache
else:
# There are missing state groups, so let's remove the stale
# entry and continue as if it was a cache miss.
self._state_cache.pop(group_names, None)

logger.info(
"Resolving state for %s with groups %s",
room_id,
list(group_names),
)

# We double check that none of the state groups have been deleted.
# They shouldn't be as all these state groups should be referenced.
missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion(
group_names
)
if missing_state_groups:
raise Exception(
f"State groups have been deleted: {shortstr(missing_state_groups)}"
)

state_groups_histogram.observe(len(state_groups_ids))

new_state = await self.resolve_events_with_store(
Expand Down Expand Up @@ -884,7 +932,8 @@ class StateResolutionStore:
in well defined way.
"""

store: "DataStore"
main_store: "DataStore"
state_deletion_store: "StateDeletionDataStore"

def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
Expand All @@ -899,7 +948,7 @@ def get_events(
An awaitable which resolves to a dict from event_id to event.
"""

return self.store.get_events(
return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
Expand All @@ -920,4 +969,4 @@ def get_auth_chain_difference(
An awaitable that resolves to a set of event IDs.
"""

return self.store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference(room_id, state_sets)
32 changes: 21 additions & 11 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def __init__(
# store for now.
self.main_store = stores.main
self.state_store = stores.state
self._state_deletion_store = stores.state_deletion

assert stores.persist_events
self.persist_events_store = stores.persist_events
Expand Down Expand Up @@ -549,7 +550,9 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
room_version,
state_maps_by_state_group,
event_map=None,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)

return await res.get_state(self._state_controller, StateFilter.all())
Expand Down Expand Up @@ -635,15 +638,20 @@ async def _persist_event_batch(
room_id, [e for e, _ in chunk]
)

await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
# Stop the state groups from being deleted while we're persisting
# them.
async with self._state_deletion_store.persisting_state_group_references(
events_and_contexts
):
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)

return replaced_events

Expand Down Expand Up @@ -965,7 +973,9 @@ async def _get_new_state_after_events(
room_version,
state_groups,
events_map,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_deletion_store
),
)

state_resolutions_during_persistence.inc()
Expand Down
10 changes: 8 additions & 2 deletions synapse/storage/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.databases.state.deletion import StateDeletionDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

Expand All @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]):
main
state
persist_events
state_deletion
"""

databases: List[DatabasePool]
main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
state_deletion: StateDeletionDataStore

def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
Expand All @@ -63,6 +66,7 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
self.databases = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
state_deletion: Optional[StateDeletionDataStore] = None
persist_events: Optional[PersistEventsStore] = None

for database_config in hs.config.database.databases:
Expand Down Expand Up @@ -114,7 +118,8 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
if state:
raise Exception("'state' data store already configured")

state = StateGroupDataStore(database, db_conn, hs)
state_deletion = StateDeletionDataStore(database, db_conn, hs)
state = StateGroupDataStore(database, db_conn, hs, state_deletion)

db_conn.commit()

Expand All @@ -135,11 +140,12 @@ def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
if not main:
raise Exception("No 'main' database configured")

if not state:
if not state or not state_deletion:
raise Exception("No 'state' database configured")

# We use local variables here to ensure that the databases do not have
# optional types.
self.main = main # type: ignore[assignment]
self.state = state
self.persist_events = persist_events
self.state_deletion = state_deletion
Loading
Loading