diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index 2c8e1d84943..abc5fca09b4 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -14,7 +14,7 @@ use std::{ collections::{BTreeMap, BTreeSet, HashMap}, - sync::RwLock as StdRwLock, + sync::RwLock, }; use async_trait::async_trait; @@ -33,7 +33,7 @@ use ruma::{ CanonicalJsonObject, EventId, OwnedEventId, OwnedMxcUri, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, RoomVersionId, TransactionId, UserId, }; -use tracing::{debug, instrument, trace, warn}; +use tracing::{debug, instrument, warn}; use super::{ send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey}, @@ -47,50 +47,49 @@ use crate::{ MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue, }; -/// In-memory, non-persistent implementation of the `StateStore`. -/// -/// Default if no other is configured at startup. -#[allow(clippy::type_complexity)] #[derive(Debug, Default)] -pub struct MemoryStore { - recently_visited_rooms: StdRwLock>>, - composer_drafts: StdRwLock>, - user_avatar_url: StdRwLock>, - sync_token: StdRwLock>, - server_capabilities: StdRwLock>, - filters: StdRwLock>, - utd_hook_manager_data: StdRwLock>, - account_data: StdRwLock>>, - profiles: StdRwLock>>, - display_names: StdRwLock>>>, - members: StdRwLock>>, - room_info: StdRwLock>, - room_state: StdRwLock< +#[allow(clippy::type_complexity)] +struct MemoryStoreInner { + recently_visited_rooms: HashMap>, + composer_drafts: HashMap, + user_avatar_url: HashMap, + sync_token: Option, + server_capabilities: Option, + filters: HashMap, + utd_hook_manager_data: Option, + account_data: HashMap>, + profiles: HashMap>, + display_names: HashMap>>, + members: HashMap>, + room_info: HashMap, + room_state: HashMap>>>, - >, - room_account_data: StdRwLock< + room_account_data: HashMap>>, - >, - stripped_room_state: StdRwLock< + stripped_room_state: HashMap>>>, + stripped_members: HashMap>, + presence: HashMap>, + room_user_receipts: HashMap< + OwnedRoomId, + HashMap<(String, Option), HashMap>, >, - stripped_members: StdRwLock>>, - presence: StdRwLock>>, - room_user_receipts: StdRwLock< - HashMap< - OwnedRoomId, - HashMap<(String, Option), HashMap>, - >, - >, - room_event_receipts: StdRwLock< - HashMap< - OwnedRoomId, - HashMap<(String, Option), HashMap>>, - >, + + room_event_receipts: HashMap< + OwnedRoomId, + HashMap<(String, Option), HashMap>>, >, - custom: StdRwLock, Vec>>, - send_queue_events: StdRwLock>>, - dependent_send_queue_events: StdRwLock>>, + custom: HashMap, Vec>, + send_queue_events: BTreeMap>, + dependent_send_queue_events: BTreeMap>, +} + +/// In-memory, non-persistent implementation of the `StateStore`. +/// +/// Default if no other is configured at startup. +#[derive(Debug, Default)] +pub struct MemoryStore { + inner: RwLock, } impl MemoryStore { @@ -106,9 +105,10 @@ impl MemoryStore { thread: ReceiptThread, user_id: &UserId, ) -> Option<(OwnedEventId, Receipt)> { - self.room_user_receipts + self.inner .read() .unwrap() + .room_user_receipts .get(room_id)? .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))? .get(user_id) @@ -123,9 +123,10 @@ impl MemoryStore { event_id: &EventId, ) -> Option> { Some( - self.room_event_receipts + self.inner .read() .unwrap() + .room_event_receipts .get(room_id)? .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))? .get(event_id)? @@ -142,50 +143,31 @@ impl StateStore for MemoryStore { type Error = StoreError; async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + let inner = self.inner.read().unwrap(); Ok(match key { StateStoreDataKey::SyncToken => { - self.sync_token.read().unwrap().clone().map(StateStoreDataValue::SyncToken) + inner.sync_token.clone().map(StateStoreDataValue::SyncToken) } - StateStoreDataKey::ServerCapabilities => self - .server_capabilities - .read() - .unwrap() - .clone() - .map(StateStoreDataValue::ServerCapabilities), - StateStoreDataKey::Filter(filter_name) => self - .filters - .read() - .unwrap() - .get(filter_name) - .cloned() - .map(StateStoreDataValue::Filter), - StateStoreDataKey::UserAvatarUrl(user_id) => self - .user_avatar_url - .read() - .unwrap() - .get(user_id) - .cloned() - .map(StateStoreDataValue::UserAvatarUrl), - StateStoreDataKey::RecentlyVisitedRooms(user_id) => self + StateStoreDataKey::ServerCapabilities => { + inner.server_capabilities.clone().map(StateStoreDataValue::ServerCapabilities) + } + StateStoreDataKey::Filter(filter_name) => { + inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter) + } + StateStoreDataKey::UserAvatarUrl(user_id) => { + inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl) + } + StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner .recently_visited_rooms - .read() - .unwrap() .get(user_id) .cloned() .map(StateStoreDataValue::RecentlyVisitedRooms), - StateStoreDataKey::UtdHookManagerData => self - .utd_hook_manager_data - .read() - .unwrap() - .clone() - .map(StateStoreDataValue::UtdHookManagerData), - StateStoreDataKey::ComposerDraft(room_id) => self - .composer_drafts - .read() - .unwrap() - .get(room_id) - .cloned() - .map(StateStoreDataValue::ComposerDraft), + StateStoreDataKey::UtdHookManagerData => { + inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData) + } + StateStoreDataKey::ComposerDraft(room_id) => { + inner.composer_drafts.get(room_id).cloned().map(StateStoreDataValue::ComposerDraft) + } }) } @@ -194,25 +176,26 @@ impl StateStore for MemoryStore { key: StateStoreDataKey<'_>, value: StateStoreDataValue, ) -> Result<()> { + let mut inner = self.inner.write().unwrap(); match key { StateStoreDataKey::SyncToken => { - *self.sync_token.write().unwrap() = + inner.sync_token = Some(value.into_sync_token().expect("Session data not a sync token")) } StateStoreDataKey::Filter(filter_name) => { - self.filters.write().unwrap().insert( + inner.filters.insert( filter_name.to_owned(), value.into_filter().expect("Session data not a filter"), ); } StateStoreDataKey::UserAvatarUrl(user_id) => { - self.user_avatar_url.write().unwrap().insert( + inner.user_avatar_url.insert( user_id.to_owned(), value.into_user_avatar_url().expect("Session data not a user avatar url"), ); } StateStoreDataKey::RecentlyVisitedRooms(user_id) => { - self.recently_visited_rooms.write().unwrap().insert( + inner.recently_visited_rooms.insert( user_id.to_owned(), value .into_recently_visited_rooms() @@ -220,20 +203,20 @@ impl StateStore for MemoryStore { ); } StateStoreDataKey::UtdHookManagerData => { - *self.utd_hook_manager_data.write().unwrap() = Some( + inner.utd_hook_manager_data = Some( value .into_utd_hook_manager_data() .expect("Session data not the hook manager data"), ); } StateStoreDataKey::ComposerDraft(room_id) => { - self.composer_drafts.write().unwrap().insert( + inner.composer_drafts.insert( room_id.to_owned(), value.into_composer_draft().expect("Session data not a composer draft"), ); } StateStoreDataKey::ServerCapabilities => { - *self.server_capabilities.write().unwrap() = Some( + inner.server_capabilities = Some( value .into_server_capabilities() .expect("Session data not containing server capabilities"), @@ -245,25 +228,22 @@ impl StateStore for MemoryStore { } async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + let mut inner = self.inner.write().unwrap(); match key { - StateStoreDataKey::SyncToken => *self.sync_token.write().unwrap() = None, - StateStoreDataKey::ServerCapabilities => { - *self.server_capabilities.write().unwrap() = None - } + StateStoreDataKey::SyncToken => inner.sync_token = None, + StateStoreDataKey::ServerCapabilities => inner.server_capabilities = None, StateStoreDataKey::Filter(filter_name) => { - self.filters.write().unwrap().remove(filter_name); + inner.filters.remove(filter_name); } StateStoreDataKey::UserAvatarUrl(user_id) => { - self.user_avatar_url.write().unwrap().remove(user_id); + inner.user_avatar_url.remove(user_id); } StateStoreDataKey::RecentlyVisitedRooms(user_id) => { - self.recently_visited_rooms.write().unwrap().remove(user_id); - } - StateStoreDataKey::UtdHookManagerData => { - *self.utd_hook_manager_data.write().unwrap() = None + inner.recently_visited_rooms.remove(user_id); } + StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None, StateStoreDataKey::ComposerDraft(room_id) => { - self.composer_drafts.write().unwrap().remove(room_id); + inner.composer_drafts.remove(room_id); } } Ok(()) @@ -272,248 +252,202 @@ impl StateStore for MemoryStore { #[instrument(skip(self, changes))] async fn save_changes(&self, changes: &StateChanges) -> Result<()> { let now = Instant::now(); - // these trace calls are to debug https://github.com/matrix-org/complement-crypto/issues/77 - trace!("starting"); + + let mut inner = self.inner.write().unwrap(); if let Some(s) = &changes.sync_token { - *self.sync_token.write().unwrap() = Some(s.to_owned()); - trace!("assigned sync token"); + inner.sync_token = Some(s.to_owned()); } - trace!("profiles"); - { - let mut profiles = self.profiles.write().unwrap(); - - for (room, users) in &changes.profiles_to_delete { - let Some(room_profiles) = profiles.get_mut(room) else { - continue; - }; - for user in users { - room_profiles.remove(user); - } + for (room, users) in &changes.profiles_to_delete { + let Some(room_profiles) = inner.profiles.get_mut(room) else { + continue; + }; + for user in users { + room_profiles.remove(user); } + } - for (room, users) in &changes.profiles { - for (user_id, profile) in users { - profiles - .entry(room.clone()) - .or_default() - .insert(user_id.clone(), profile.clone()); - } + for (room, users) in &changes.profiles { + for (user_id, profile) in users { + inner + .profiles + .entry(room.clone()) + .or_default() + .insert(user_id.clone(), profile.clone()); } } - trace!("ambiguity maps"); for (room, map) in &changes.ambiguity_maps { for (display_name, display_names) in map { - self.display_names - .write() - .unwrap() + inner + .display_names .entry(room.clone()) .or_default() .insert(display_name.clone(), display_names.clone()); } } - trace!("account data"); - { - let mut account_data = self.account_data.write().unwrap(); - for (event_type, event) in &changes.account_data { - account_data.insert(event_type.clone(), event.clone()); + for (event_type, event) in &changes.account_data { + inner.account_data.insert(event_type.clone(), event.clone()); + } + + for (room, events) in &changes.room_account_data { + for (event_type, event) in events { + inner + .room_account_data + .entry(room.clone()) + .or_default() + .insert(event_type.clone(), event.clone()); } } - trace!("room account data"); - { - let mut room_account_data = self.room_account_data.write().unwrap(); - for (room, events) in &changes.room_account_data { - for (event_type, event) in events { - room_account_data + for (room, event_types) in &changes.state { + for (event_type, events) in event_types { + for (state_key, raw_event) in events { + inner + .room_state .entry(room.clone()) .or_default() - .insert(event_type.clone(), event.clone()); - } - } - } + .entry(event_type.clone()) + .or_default() + .insert(state_key.to_owned(), raw_event.clone()); + inner.stripped_room_state.remove(room); + + if *event_type == StateEventType::RoomMember { + let event = match raw_event.deserialize_as::() { + Ok(ev) => ev, + Err(e) => { + let event_id: Option = + raw_event.get_field("event_id").ok().flatten(); + debug!(event_id, "Failed to deserialize member event: {e}"); + continue; + } + }; - trace!("room state"); - { - let mut room_state = self.room_state.write().unwrap(); - trace!("room state: got room_state lock"); - let mut stripped_room_state = self.stripped_room_state.write().unwrap(); - trace!("room state: got stripped_room_state lock"); - let mut members = self.members.write().unwrap(); - trace!("room state: got members lock"); - let mut stripped_members = self.stripped_members.write().unwrap(); - trace!("room state: got stripped_members lock"); - - for (room, event_types) in &changes.state { - for (event_type, events) in event_types { - for (state_key, raw_event) in events { - room_state + inner.stripped_members.remove(room); + + inner + .members .entry(room.clone()) .or_default() - .entry(event_type.clone()) - .or_default() - .insert(state_key.to_owned(), raw_event.clone()); - stripped_room_state.remove(room); - - if *event_type == StateEventType::RoomMember { - let event = match raw_event.deserialize_as::() { - Ok(ev) => ev, - Err(e) => { - let event_id: Option = - raw_event.get_field("event_id").ok().flatten(); - debug!(event_id, "Failed to deserialize member event: {e}"); - continue; - } - }; - - stripped_members.remove(room); - - members - .entry(room.clone()) - .or_default() - .insert(event.state_key().to_owned(), event.membership().clone()); - } + .insert(event.state_key().to_owned(), event.membership().clone()); } } } } - trace!("room info"); - { - let mut room_info = self.room_info.write().unwrap(); - for (room_id, info) in &changes.room_infos { - room_info.insert(room_id.clone(), info.clone()); - } + for (room_id, info) in &changes.room_infos { + inner.room_info.insert(room_id.clone(), info.clone()); } - trace!("presence"); - { - let mut presence = self.presence.write().unwrap(); - for (sender, event) in &changes.presence { - presence.insert(sender.clone(), event.clone()); - } + for (sender, event) in &changes.presence { + inner.presence.insert(sender.clone(), event.clone()); } - trace!("stripped state"); - { - let mut stripped_room_state = self.stripped_room_state.write().unwrap(); - let mut stripped_members = self.stripped_members.write().unwrap(); + for (room, event_types) in &changes.stripped_state { + for (event_type, events) in event_types { + for (state_key, raw_event) in events { + inner + .stripped_room_state + .entry(room.clone()) + .or_default() + .entry(event_type.clone()) + .or_default() + .insert(state_key.to_owned(), raw_event.clone()); + + if *event_type == StateEventType::RoomMember { + let event = match raw_event.deserialize_as::() { + Ok(ev) => ev, + Err(e) => { + let event_id: Option = + raw_event.get_field("event_id").ok().flatten(); + debug!( + event_id, + "Failed to deserialize stripped member event: {e}" + ); + continue; + } + }; - for (room, event_types) in &changes.stripped_state { - for (event_type, events) in event_types { - for (state_key, raw_event) in events { - stripped_room_state + inner + .stripped_members .entry(room.clone()) .or_default() - .entry(event_type.clone()) - .or_default() - .insert(state_key.to_owned(), raw_event.clone()); - - if *event_type == StateEventType::RoomMember { - let event = match raw_event.deserialize_as::() - { - Ok(ev) => ev, - Err(e) => { - let event_id: Option = - raw_event.get_field("event_id").ok().flatten(); - debug!( - event_id, - "Failed to deserialize stripped member event: {e}" - ); - continue; - } - }; - - stripped_members - .entry(room.clone()) - .or_default() - .insert(event.state_key, event.content.membership.clone()); - } + .insert(event.state_key, event.content.membership.clone()); } } } } - trace!("receipts"); - { - let mut room_user_receipts = self.room_user_receipts.write().unwrap(); - let mut room_event_receipts = self.room_event_receipts.write().unwrap(); - - for (room, content) in &changes.receipts { - for (event_id, receipts) in &content.0 { - for (receipt_type, receipts) in receipts { - for (user_id, receipt) in receipts { - let thread = receipt.thread.as_str().map(ToOwned::to_owned); - // Add the receipt to the room user receipts - if let Some((old_event, _)) = room_user_receipts - .entry(room.clone()) - .or_default() - .entry((receipt_type.to_string(), thread.clone())) - .or_default() - .insert(user_id.clone(), (event_id.clone(), receipt.clone())) - { - // Remove the old receipt from the room event receipts - if let Some(receipt_map) = room_event_receipts.get_mut(room) { - if let Some(event_map) = receipt_map - .get_mut(&(receipt_type.to_string(), thread.clone())) - { - if let Some(user_map) = event_map.get_mut(&old_event) { - user_map.remove(user_id); - } + for (room, content) in &changes.receipts { + for (event_id, receipts) in &content.0 { + for (receipt_type, receipts) in receipts { + for (user_id, receipt) in receipts { + let thread = receipt.thread.as_str().map(ToOwned::to_owned); + // Add the receipt to the room user receipts + if let Some((old_event, _)) = inner + .room_user_receipts + .entry(room.clone()) + .or_default() + .entry((receipt_type.to_string(), thread.clone())) + .or_default() + .insert(user_id.clone(), (event_id.clone(), receipt.clone())) + { + // Remove the old receipt from the room event receipts + if let Some(receipt_map) = inner.room_event_receipts.get_mut(room) { + if let Some(event_map) = + receipt_map.get_mut(&(receipt_type.to_string(), thread.clone())) + { + if let Some(user_map) = event_map.get_mut(&old_event) { + user_map.remove(user_id); } } } - - // Add the receipt to the room event receipts - room_event_receipts - .entry(room.clone()) - .or_default() - .entry((receipt_type.to_string(), thread)) - .or_default() - .entry(event_id.clone()) - .or_default() - .insert(user_id.clone(), receipt.clone()); } + + // Add the receipt to the room event receipts + inner + .room_event_receipts + .entry(room.clone()) + .or_default() + .entry((receipt_type.to_string(), thread)) + .or_default() + .entry(event_id.clone()) + .or_default() + .insert(user_id.clone(), receipt.clone()); } } } } - trace!("room info/state"); - { - let room_info = self.room_info.read().unwrap(); - let mut room_state = self.room_state.write().unwrap(); - - let make_room_version = |room_id| { - room_info.get(room_id).and_then(|info| info.room_version().cloned()).unwrap_or_else( - || { - warn!(?room_id, "Unable to find the room version, assuming version 9"); - RoomVersionId::V9 - }, - ) - }; + let make_room_version = |room_info: &HashMap, room_id| { + room_info.get(room_id).and_then(|info| info.room_version().cloned()).unwrap_or_else( + || { + warn!(?room_id, "Unable to find the room version, assuming version 9"); + RoomVersionId::V9 + }, + ) + }; - for (room_id, redactions) in &changes.redactions { - let mut room_version = None; - if let Some(room) = room_state.get_mut(room_id) { - for ref_room_mu in room.values_mut() { - for raw_evt in ref_room_mu.values_mut() { - if let Ok(Some(event_id)) = - raw_evt.get_field::("event_id") - { - if let Some(redaction) = redactions.get(&event_id) { - let redacted = redact( - raw_evt.deserialize_as::()?, - room_version - .get_or_insert_with(|| make_room_version(room_id)), - Some(RedactedBecause::from_raw_event(redaction)?), - ) - .map_err(StoreError::Redaction)?; - *raw_evt = Raw::new(&redacted)?.cast(); - } + let inner = &mut *inner; + for (room_id, redactions) in &changes.redactions { + let mut room_version = None; + + if let Some(room) = inner.room_state.get_mut(room_id) { + for ref_room_mu in room.values_mut() { + for raw_evt in ref_room_mu.values_mut() { + if let Ok(Some(event_id)) = raw_evt.get_field::("event_id") { + if let Some(redaction) = redactions.get(&event_id) { + let redacted = redact( + raw_evt.deserialize_as::()?, + room_version.get_or_insert_with(|| { + make_room_version(&inner.room_info, room_id) + }), + Some(RedactedBecause::from_raw_event(redaction)?), + ) + .map_err(StoreError::Redaction)?; + *raw_evt = Raw::new(&redacted)?.cast(); } } } @@ -527,14 +461,14 @@ impl StateStore for MemoryStore { } async fn get_presence_event(&self, user_id: &UserId) -> Result>> { - Ok(self.presence.read().unwrap().get(user_id).cloned()) + Ok(self.inner.read().unwrap().presence.get(user_id).cloned()) } async fn get_presence_events( &self, user_ids: &[OwnedUserId], ) -> Result>> { - let presence = self.presence.read().unwrap(); + let presence = &self.inner.read().unwrap().presence; Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect()) } @@ -566,18 +500,17 @@ impl StateStore for MemoryStore { Some(state_events.values().cloned().map(to_enum).collect()) } - let state_map = self.stripped_room_state.read().unwrap(); - Ok(get_events(&state_map, room_id, &event_type, RawAnySyncOrStrippedState::Stripped) - .or_else(|| { - drop(state_map); // release the lock on stripped_room_state - get_events( - &self.room_state.read().unwrap(), - room_id, - &event_type, - RawAnySyncOrStrippedState::Sync, - ) - }) - .unwrap_or_default()) + let inner = self.inner.read().unwrap(); + Ok(get_events( + &inner.stripped_room_state, + room_id, + &event_type, + RawAnySyncOrStrippedState::Stripped, + ) + .or_else(|| { + get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync) + }) + .unwrap_or_default()) } async fn get_state_events_for_keys( @@ -586,41 +519,31 @@ impl StateStore for MemoryStore { event_type: StateEventType, state_keys: &[&str], ) -> Result, Self::Error> { - Ok( - if let Some(stripped_state_events) = self - .stripped_room_state - .read() - .unwrap() - .get(room_id) - .and_then(|events| events.get(&event_type)) - { - state_keys - .iter() - .filter_map(|k| { - stripped_state_events - .get(*k) - .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone())) - }) - .collect() - } else if let Some(sync_state_events) = self - .room_state - .read() - .unwrap() - .get(room_id) - .and_then(|events| events.get(&event_type)) - { - state_keys - .iter() - .filter_map(|k| { - sync_state_events - .get(*k) - .map(|e| RawAnySyncOrStrippedState::Sync(e.clone())) - }) - .collect() - } else { - Vec::new() - }, - ) + let inner = self.inner.read().unwrap(); + + if let Some(stripped_state_events) = + inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type)) + { + Ok(state_keys + .iter() + .filter_map(|k| { + stripped_state_events + .get(*k) + .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone())) + }) + .collect()) + } else if let Some(sync_state_events) = + inner.room_state.get(room_id).and_then(|events| events.get(&event_type)) + { + Ok(state_keys + .iter() + .filter_map(|k| { + sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone())) + }) + .collect()) + } else { + Ok(Vec::new()) + } } async fn get_profile( @@ -629,9 +552,10 @@ impl StateStore for MemoryStore { user_id: &UserId, ) -> Result> { Ok(self - .profiles + .inner .read() .unwrap() + .profiles .get(room_id) .and_then(|room_profiles| room_profiles.get(user_id)) .cloned()) @@ -646,7 +570,7 @@ impl StateStore for MemoryStore { return Ok(BTreeMap::new()); } - let profiles = self.profiles.read().unwrap(); + let profiles = &self.inner.read().unwrap().profiles; let Some(room_profiles) = profiles.get(room_id) else { return Ok(BTreeMap::new()); }; @@ -686,17 +610,16 @@ impl StateStore for MemoryStore { }) .unwrap_or_default() } - let state_map = self.stripped_members.read().unwrap(); - let v = get_user_ids_inner(&state_map, room_id, memberships); + let inner = self.inner.read().unwrap(); + let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships); if !v.is_empty() { return Ok(v); } - drop(state_map); // release the stripped_members lock - Ok(get_user_ids_inner(&self.members.read().unwrap(), room_id, memberships)) + Ok(get_user_ids_inner(&inner.members, room_id, memberships)) } async fn get_room_infos(&self) -> Result> { - Ok(self.room_info.read().unwrap().values().cloned().collect()) + Ok(self.inner.read().unwrap().room_info.values().cloned().collect()) } async fn get_users_with_display_name( @@ -705,9 +628,10 @@ impl StateStore for MemoryStore { display_name: &DisplayName, ) -> Result> { Ok(self - .display_names + .inner .read() .unwrap() + .display_names .get(room_id) .and_then(|room_names| room_names.get(display_name).cloned()) .unwrap_or_default()) @@ -722,8 +646,8 @@ impl StateStore for MemoryStore { return Ok(HashMap::new()); } - let read_guard = &self.display_names.read().unwrap(); - let Some(room_names) = read_guard.get(room_id) else { + let inner = self.inner.read().unwrap(); + let Some(room_names) = inner.display_names.get(room_id) else { return Ok(HashMap::new()); }; @@ -734,7 +658,7 @@ impl StateStore for MemoryStore { &self, event_type: GlobalAccountDataEventType, ) -> Result>> { - Ok(self.account_data.read().unwrap().get(&event_type).cloned()) + Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned()) } async fn get_room_account_data_event( @@ -743,9 +667,10 @@ impl StateStore for MemoryStore { event_type: RoomAccountDataEventType, ) -> Result>> { Ok(self - .room_account_data + .inner .read() .unwrap() + .room_account_data .get(room_id) .and_then(|m| m.get(&event_type)) .cloned()) @@ -774,30 +699,32 @@ impl StateStore for MemoryStore { } async fn get_custom_value(&self, key: &[u8]) -> Result>> { - Ok(self.custom.read().unwrap().get(key).cloned()) + Ok(self.inner.read().unwrap().custom.get(key).cloned()) } async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { - Ok(self.custom.write().unwrap().insert(key.to_vec(), value)) + Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value)) } async fn remove_custom_value(&self, key: &[u8]) -> Result>> { - Ok(self.custom.write().unwrap().remove(key)) + Ok(self.inner.write().unwrap().custom.remove(key)) } async fn remove_room(&self, room_id: &RoomId) -> Result<()> { - self.profiles.write().unwrap().remove(room_id); - self.display_names.write().unwrap().remove(room_id); - self.members.write().unwrap().remove(room_id); - self.room_info.write().unwrap().remove(room_id); - self.room_state.write().unwrap().remove(room_id); - self.room_account_data.write().unwrap().remove(room_id); - self.stripped_room_state.write().unwrap().remove(room_id); - self.stripped_members.write().unwrap().remove(room_id); - self.room_user_receipts.write().unwrap().remove(room_id); - self.room_event_receipts.write().unwrap().remove(room_id); - self.send_queue_events.write().unwrap().remove(room_id); - self.dependent_send_queue_events.write().unwrap().remove(room_id); + let mut inner = self.inner.write().unwrap(); + + inner.profiles.remove(room_id); + inner.display_names.remove(room_id); + inner.members.remove(room_id); + inner.room_info.remove(room_id); + inner.room_state.remove(room_id); + inner.room_account_data.remove(room_id); + inner.stripped_room_state.remove(room_id); + inner.stripped_members.remove(room_id); + inner.room_user_receipts.remove(room_id); + inner.room_event_receipts.remove(room_id); + inner.send_queue_events.remove(room_id); + inner.dependent_send_queue_events.remove(room_id); Ok(()) } @@ -809,9 +736,10 @@ impl StateStore for MemoryStore { kind: QueuedRequestKind, priority: usize, ) -> Result<(), Self::Error> { - self.send_queue_events + self.inner .write() .unwrap() + .send_queue_events .entry(room_id.to_owned()) .or_default() .push(QueuedRequest { kind, transaction_id, error: None, priority }); @@ -825,9 +753,10 @@ impl StateStore for MemoryStore { kind: QueuedRequestKind, ) -> Result { if let Some(entry) = self - .send_queue_events + .inner .write() .unwrap() + .send_queue_events .entry(room_id.to_owned()) .or_default() .iter_mut() @@ -846,7 +775,8 @@ impl StateStore for MemoryStore { room_id: &RoomId, transaction_id: &TransactionId, ) -> Result { - let mut q = self.send_queue_events.write().unwrap(); + let mut inner = self.inner.write().unwrap(); + let q = &mut inner.send_queue_events; let entry = q.get_mut(room_id); if let Some(entry) = entry { @@ -868,8 +798,14 @@ impl StateStore for MemoryStore { &self, room_id: &RoomId, ) -> Result, Self::Error> { - let mut ret = - self.send_queue_events.write().unwrap().entry(room_id.to_owned()).or_default().clone(); + let mut ret = self + .inner + .write() + .unwrap() + .send_queue_events + .entry(room_id.to_owned()) + .or_default() + .clone(); // Inverted order of priority, use stable sort to keep insertion order. ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority)); Ok(ret) @@ -882,9 +818,10 @@ impl StateStore for MemoryStore { error: Option, ) -> Result<(), Self::Error> { if let Some(entry) = self - .send_queue_events + .inner .write() .unwrap() + .send_queue_events .entry(room_id.to_owned()) .or_default() .iter_mut() @@ -896,7 +833,7 @@ impl StateStore for MemoryStore { } async fn load_rooms_with_unsent_requests(&self) -> Result, Self::Error> { - Ok(self.send_queue_events.read().unwrap().keys().cloned().collect()) + Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect()) } async fn save_dependent_queued_request( @@ -906,14 +843,18 @@ impl StateStore for MemoryStore { own_transaction_id: ChildTransactionId, content: DependentQueuedRequestKind, ) -> Result<(), Self::Error> { - self.dependent_send_queue_events.write().unwrap().entry(room.to_owned()).or_default().push( - DependentQueuedRequest { + self.inner + .write() + .unwrap() + .dependent_send_queue_events + .entry(room.to_owned()) + .or_default() + .push(DependentQueuedRequest { kind: content, parent_transaction_id: parent_transaction_id.to_owned(), own_transaction_id, parent_key: None, - }, - ); + }); Ok(()) } @@ -923,8 +864,8 @@ impl StateStore for MemoryStore { parent_txn_id: &TransactionId, sent_parent_key: SentRequestKey, ) -> Result { - let mut dependent_send_queue_events = self.dependent_send_queue_events.write().unwrap(); - let dependents = dependent_send_queue_events.entry(room.to_owned()).or_default(); + let mut inner = self.inner.write().unwrap(); + let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default(); let mut num_updated = 0; for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) { d.parent_key = Some(sent_parent_key.clone()); @@ -939,8 +880,8 @@ impl StateStore for MemoryStore { own_transaction_id: &ChildTransactionId, new_content: DependentQueuedRequestKind, ) -> Result { - let mut dependent_send_queue_events = self.dependent_send_queue_events.write().unwrap(); - let dependents = dependent_send_queue_events.entry(room.to_owned()).or_default(); + let mut inner = self.inner.write().unwrap(); + let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default(); for d in dependents.iter_mut() { if d.own_transaction_id == *own_transaction_id { d.kind = new_content; @@ -955,8 +896,8 @@ impl StateStore for MemoryStore { room: &RoomId, txn_id: &ChildTransactionId, ) -> Result { - let mut dependent_send_queue_events = self.dependent_send_queue_events.write().unwrap(); - let dependents = dependent_send_queue_events.entry(room.to_owned()).or_default(); + let mut inner = self.inner.write().unwrap(); + let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default(); if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) { dependents.remove(pos); Ok(true) @@ -973,7 +914,14 @@ impl StateStore for MemoryStore { &self, room: &RoomId, ) -> Result, Self::Error> { - Ok(self.dependent_send_queue_events.read().unwrap().get(room).cloned().unwrap_or_default()) + Ok(self + .inner + .read() + .unwrap() + .dependent_send_queue_events + .get(room) + .cloned() + .unwrap_or_default()) } }