diff --git a/.lnav-json-config.json b/.lnav-json-config.json new file mode 100644 index 000000000..ffaad31dd --- /dev/null +++ b/.lnav-json-config.json @@ -0,0 +1,115 @@ +{ + "$schema": "https://lnav.org/schemas/format-v1.schema.json", + "libxmtp_json_log": { + "title": "Libxmtp Log", + "url": "https://github.com/xmtp/libxmtp", + "json": true, + "file-type": "json", + "hide-extra": true, + "level-field": "level", + "timestamp-field": "timestamp", + "timestamp-format": "%a, %b %d, %Y %I:%M:%S %p %Z", + "body-field": "message", + "level": { + "error": "ERROR", + "warning": "WARN", + "info": "INFO", + "debug": "DEBUG", + "trace": "TRACE" + }, + "line-format": [ + { + "field": "__timestamp__", + "align": "right" + }, + " ", + { + "field": "__level__", + "min-width": 4, + "max-width": 4, + "align": "right", + "text-transform": "uppercase", + "suffix": ":" + }, + " ", + { + "prefix": "target=", + "field": "target", + "align": "left", + "default-value": "" + }, + " ", + { + "field": "message" + }, + " ", + { + "prefix": "span: ", + "field": "span", + "default-value": "" + }, + " ", + { + "prefix": "spans: ", + "field": "spans", + "default-value": "" + }, + " ", + { + "prefix": "signer=", + "field": "signer", + "default-value": "" + }, + " ", + { + "prefix": "missing_signatures=", + "field": "missing_signatures", + "default-value": "" + }, + " ", + { + "prefix": "inbox_id=", + "field": "inbox_id", + "default-value": "" + }, + " ", + { + "prefix": "sender_inbox_id=", + "field": "sender_inbox_id", + "default-value": "" + }, + " ", + { + "prefix": "installation_id=", + "field": "installation_id", + "default-value": "" + }, + " ", + { + "prefix": "sender_installation_id=", + "field": "sender_installation_id", + "default-value": "" + }, + " ", + { + "prefix": "group_id=", + "field": "group_id", + "default-value": "" + } + ], + "value": { + "message": { "kind": "string" }, + "target": { "kind": "string" }, + "signer": { "kind": "string" }, + "message_id": { "kind": "string" }, + "installation_id": { "kind": "string" }, + "sender_installation_id": { "kind": "string" }, + "inbox_id": { "kind": "string" }, + "sender_inbox_id": { "kind": "string" }, + "group_id": { "kind": "string" }, + "missing_signatures": { "kind": "string" }, + "span": { "kind": "string", "hidden": true }, + "spans": { "kind": "string", "hidden": true } + } + } +} diff --git a/Cargo.lock b/Cargo.lock index f4b4f3a53..336e606c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7480,7 +7480,6 @@ version = "0.1.0" dependencies = [ "aes-gcm", "anyhow", - "async-stream", "async-trait", "bincode", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 73137934c..1f9f270c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,7 @@ sqlite-web = "0.0.1" tonic = { version = "0.12", default-features = false } tracing = { version = "0.1", features = ["log"] } tracing-subscriber = { version = "0.3", default-features = false } +tracing-logfmt = "0.3" trait-variant = "0.1.2" url = "2.5.0" wasm-bindgen = "=0.2.100" diff --git a/xmtp_debug/Cargo.toml b/xmtp_debug/Cargo.toml index a8284800f..9ed97b708 100644 --- a/xmtp_debug/Cargo.toml +++ b/xmtp_debug/Cargo.toml @@ -21,7 +21,7 @@ xmtp_proto.workspace = true openmls.workspace = true indicatif = "0.17" color-eyre = "0.6" -tracing-logfmt = "0.3" +tracing-logfmt.workspace = true owo-colors = "4.1" url.workspace = true redb = "2.4" diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 506e0b698..7ee73643f 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -45,7 +45,6 @@ update-schema = ["toml"] [dependencies] aes-gcm = { version = "0.10.3", features = ["std"] } -async-stream.workspace = true async-trait.workspace = true bincode.workspace = true diesel_migrations.workspace = true diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index abc134024..2be2dfffd 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -735,9 +735,7 @@ impl MlsGroup { /// Send a message on this users XMTP [`Client`]. pub async fn send_message(&self, message: &[u8]) -> Result, GroupError> { - tracing::debug!(inbox_id = self.client.inbox_id(), "sending message"); - let conn = self.context().store().conn()?; - let provider = XmtpOpenMlsProvider::from(conn); + let provider = self.mls_provider()?; self.send_message_with_provider(message, &provider).await } diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 55f735853..2b9c2ae4a 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -9,8 +9,8 @@ use crate::{ groups::ScopedGroupClient, storage::group_message::StoredGroupMessage, subscriptions::{ - stream_messages::{MessagesStreamInfo, ProcessMessageFuture, StreamGroupMessages}, - SubscribeError, + stream_messages::{ProcessMessageFuture, StreamGroupMessages}, + Result, SubscribeError, }, }; use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams}; @@ -24,36 +24,34 @@ impl MlsGroup { pub async fn process_streamed_group_message( &self, envelope_bytes: Vec, - ) -> Result { + ) -> Result { let envelope = GroupMessage::decode(envelope_bytes.as_slice())?; ProcessMessageFuture::new(&self.client, envelope)? .process() .await + .map(|(group, _)| group) } pub async fn stream<'a>( &'a self, - ) -> Result< - impl Stream> + use<'a, ScopedClient>, - SubscribeError, - > + ) -> Result> + use<'a, ScopedClient>> where ::ApiClient: XmtpMlsStreams + 'a, { - let group_list = HashMap::from([(self.group_id.clone(), MessagesStreamInfo { cursor: 0 })]); - Ok(StreamGroupMessages::new(&self.client, &group_list).await?) + let group_list = HashMap::from([(self.group_id.clone(), 0u64.into())]); + Ok(StreamGroupMessages::new(&self.client, group_list).await?) } pub fn stream_with_callback( client: ScopedClient, group_id: Vec, - callback: impl FnMut(Result) + Send + 'static, - ) -> impl crate::StreamHandle> + callback: impl FnMut(Result) + Send + 'static, + ) -> impl crate::StreamHandle> where ScopedClient: 'static, ::ApiClient: XmtpMlsStreams + 'static, { - let group_list = HashMap::from([(group_id, MessagesStreamInfo { cursor: 0 })]); + let group_list = HashMap::from([(group_id, 0)]); stream_messages_with_callback(client, group_list, callback) } } @@ -62,9 +60,9 @@ impl MlsGroup { /// messages along to a callback. pub(crate) fn stream_messages_with_callback( client: ScopedClient, - group_id_to_info: HashMap, MessagesStreamInfo>, - mut callback: impl FnMut(Result) + Send + 'static, -) -> impl crate::StreamHandle> + active_conversations: HashMap, u64>, + mut callback: impl FnMut(Result) + Send + 'static, +) -> impl crate::StreamHandle> where ScopedClient: ScopedGroupClient + 'static, ::ApiClient: XmtpApi + XmtpMlsStreams + 'static, @@ -73,7 +71,11 @@ where crate::spawn(Some(rx), async move { let client_ref = &client; - let stream = StreamGroupMessages::new(client_ref, &group_id_to_info).await?; + let active_conversations = active_conversations + .into_iter() + .map(|(g, c)| (g, c.into())) + .collect(); + let stream = StreamGroupMessages::new(client_ref, active_conversations).await?; futures::pin_mut!(stream); let _ = tx.send(()); while let Some(message) = stream.next().await { diff --git a/xmtp_mls/src/subscriptions/mod.rs b/xmtp_mls/src/subscriptions/mod.rs index c71342af8..cded55008 100644 --- a/xmtp_mls/src/subscriptions/mod.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -1,12 +1,6 @@ use futures::{FutureExt, Stream, StreamExt}; use prost::Message; -use std::{ - collections::{HashMap, HashSet}, - future::Future, - pin::Pin, - sync::Arc, - task::Poll, -}; +use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc, task::Poll}; use tokio::{ sync::{broadcast, oneshot}, task::JoinHandle, @@ -16,15 +10,14 @@ use tracing::instrument; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; +use stream_all::StreamAllMessages; use stream_conversations::{ProcessWelcomeFuture, StreamConversations, WelcomeOrGroup}; -use stream_messages::{MessagesStreamInfo, StreamGroupMessages}; -// mod stream_all; +mod stream_all; mod stream_conversations; pub(crate) mod stream_messages; use crate::{ - client::ClientError, groups::{ device_sync::preference_sync::UserPreferenceUpdate, mls_sync::GroupMessageProcessingError, GroupError, MlsGroup, @@ -33,13 +26,15 @@ use crate::{ consent_record::StoredConsentRecord, group::{ConversationType, GroupQueryArgs, StoredGroup}, group_message::StoredGroupMessage, - ProviderTransactions, StorageError, NotFound + ProviderTransactions, StorageError, NotFound, group::ConversationType }, Client, XmtpApi, }; use thiserror::Error; use xmtp_common::{retryable, RetryableError}; +pub(crate) type Result = std::result::Result; + #[derive(Debug, Error)] pub enum LocalEventError { #[error("Unable to send event: {0}")] @@ -59,8 +54,8 @@ pub struct FutureWrapper<'a, O> { } #[cfg(target_arch = "wasm32")] -pub struct FutureWrapper<'a, C> { - inner: Pin, Option)>> + 'a>>, +pub struct FutureWrapper<'a, O> { + inner: Pin + 'a>>, } impl<'a, O> Future for FutureWrapper<'a, O> { @@ -199,18 +194,14 @@ impl LocalEvents { } pub(crate) trait StreamMessages { - fn stream_sync_messages(self) -> impl Stream>; - fn stream_consent_updates( - self, - ) -> impl Stream, SubscribeError>>; - fn stream_preference_updates( - self, - ) -> impl Stream, SubscribeError>>; + fn stream_sync_messages(self) -> impl Stream>; + fn stream_consent_updates(self) -> impl Stream>>; + fn stream_preference_updates(self) -> impl Stream>>; } impl StreamMessages for broadcast::Receiver { #[instrument(level = "trace", skip_all)] - fn stream_sync_messages(self) -> impl Stream> { + fn stream_sync_messages(self) -> impl Stream> { BroadcastStream::new(self).filter_map(|event| async { xmtp_common::optify!(event, "Missed message due to event queue lag") .and_then(LocalEvents::sync_filter) @@ -218,9 +209,7 @@ impl StreamMessages for broadcast::Receiver { }) } - fn stream_consent_updates( - self, - ) -> impl Stream, SubscribeError>> { + fn stream_consent_updates(self) -> impl Stream>> { BroadcastStream::new(self).filter_map(|event| async { xmtp_common::optify!(event, "Missed message due to event queue lag") .and_then(LocalEvents::consent_filter) @@ -228,9 +217,7 @@ impl StreamMessages for broadcast::Receiver { }) } - fn stream_preference_updates( - self, - ) -> impl Stream, SubscribeError>> { + fn stream_preference_updates(self) -> impl Stream>> { BroadcastStream::new(self).filter_map(|event| async { xmtp_common::optify!(event, "Missed message due to event queue lag") .and_then(LocalEvents::preference_filter) @@ -254,19 +241,6 @@ impl From> for JoinHandle { } } -impl From for (Vec, MessagesStreamInfo) { - fn from(group: StoredGroup) -> (Vec, MessagesStreamInfo) { - (group.id, MessagesStreamInfo { cursor: 0 }) - } -} - -// TODO: REMOVE BEFORE MERGING -// TODO: REMOVE BEFORE MERGING -// TODO: REMOVE BEFORE MERGING -pub(self) mod temp { - pub(super) type Result = std::result::Result; -} - #[derive(thiserror::Error, Debug)] pub enum SubscribeError { #[error(transparent)] @@ -321,7 +295,7 @@ where pub async fn process_streamed_welcome_message( &self, envelope_bytes: Vec, - ) -> Result, SubscribeError> { + ) -> Result> { let provider = self.mls_provider()?; let conn = provider.conn_ref(); let envelope = WelcomeMessage::decode(envelope_bytes.as_slice()) @@ -348,7 +322,7 @@ where pub async fn stream_conversations<'a>( &'a self, conversation_type: Option, - ) -> Result, SubscribeError>> + 'a, SubscribeError> + ) -> Result>> + 'a> where ApiClient: XmtpMlsStreams, { @@ -364,8 +338,8 @@ where pub fn stream_conversations_with_callback( client: Arc>, conversation_type: Option, - mut convo_callback: impl FnMut(Result, SubscribeError>) + Send + 'static, - ) -> impl crate::StreamHandle> { + mut convo_callback: impl FnMut(Result>) + Send + 'static, + ) -> impl crate::StreamHandle> { let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { @@ -385,107 +359,22 @@ where pub async fn stream_all_messages( &self, conversation_type: Option, - ) -> Result> + '_, SubscribeError> - { + ) -> Result> + '_> { tracing::debug!( inbox_id = self.inbox_id(), + installation_id = %self.context().installation_public_key(), conversation_type = ?conversation_type, "stream all messages" ); - let mut group_list = async { - let provider = self.mls_provider()?; - self.sync_welcomes(&provider).await?; - - let group_list = provider - .conn_ref() - .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? - .into_iter() - .map(Into::into) - .collect::, MessagesStreamInfo>>(); - Ok::<_, SubscribeError>(group_list) - } - .await?; - - let stream = async_stream::stream! { - let messages_stream = StreamGroupMessages::new( - self, - &group_list - ) - .await?; - futures::pin_mut!(messages_stream); - - let convo_stream = self.stream_conversations(conversation_type).await?; - futures::pin_mut!(convo_stream); - - tracing::info!("\n\n Waiting on messages \n\n"); - let mut extra_messages = Vec::new(); - - loop { - tokio::select! { - // biased enforces an order to select!. If a message and a group are both ready - // at the same time, `biased` mode will process the message before the new - // group. - biased; - - messages = futures::future::ready(&mut extra_messages), if !extra_messages.is_empty() => { - for message in messages.drain(0..) { - yield message; - } - }, - Some(message) = messages_stream.next() => { - // an error can only mean the receiver has been dropped or closed so we're - // safe to end the stream - yield message; - } - Some(new_group) = convo_stream.next() => { - match new_group { - Ok(new_group) => { - tracing::info!("Received new conversation inside streamAllMessages"); - if group_list.contains_key(&new_group.group_id) { - continue; - } - for info in group_list.values_mut() { - info.cursor = 0; - } - group_list.insert( - new_group.group_id, - MessagesStreamInfo { - cursor: 1, // For the new group, stream all messages since the group was created - }, - ); - let new_messages_stream = match StreamGroupMessages::new(self, &group_list).await { - Ok(s) => s, - Err(e) => { - yield Err(e); - continue; - }, - }; - - tracing::debug!("switching streams"); - // attempt to drain all ready messages from existing stream - while let Some(Some(message)) = messages_stream.next().now_or_never() { - extra_messages.push(message); - } - messages_stream.set(new_messages_stream); - continue; - }, - Err(e) => { - yield Err(e) - } - } - }, - } - } - }; - Ok(stream) + StreamAllMessages::new(self, conversation_type).await } pub fn stream_all_messages_with_callback( client: Arc>, conversation_type: Option, - mut callback: impl FnMut(Result) + Send + 'static, - ) -> impl crate::StreamHandle> { + mut callback: impl FnMut(Result) + Send + 'static, + ) -> impl crate::StreamHandle> { let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { @@ -502,8 +391,8 @@ where pub fn stream_consent_with_callback( client: Arc>, - mut callback: impl FnMut(Result, SubscribeError>) + Send + 'static, - ) -> impl crate::StreamHandle> { + mut callback: impl FnMut(Result>) + Send + 'static, + ) -> impl crate::StreamHandle> { let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { @@ -516,14 +405,14 @@ where callback(message) } tracing::debug!("`stream_consent` stream ended, dropping stream"); - Ok::<_, ClientError>(()) + Ok::<_, SubscribeError>(()) }) } pub fn stream_preferences_with_callback( client: Arc>, - mut callback: impl FnMut(Result, SubscribeError>) + Send + 'static, - ) -> impl crate::StreamHandle> { + mut callback: impl FnMut(Result>) + Send + 'static, + ) -> impl crate::StreamHandle> { let (tx, rx) = oneshot::channel(); crate::spawn(Some(rx), async move { @@ -536,7 +425,7 @@ where callback(message) } tracing::debug!("`stream_consent` stream ended, dropping stream"); - Ok::<_, ClientError>(()) + Ok::<_, SubscribeError>(()) }) } } @@ -546,23 +435,6 @@ pub(crate) mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); - use crate::{ - builder::ClientBuilder, - groups::GroupMetadataOptions, - storage::{group::ConversationType, group_message::StoredGroupMessage}, - utils::test::{Delivery, TestClient}, - Client, StreamHandle, - }; - use futures::StreamExt; - use parking_lot::Mutex; - use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }; - use wasm_bindgen_test::wasm_bindgen_test; - use xmtp_cryptography::utils::generate_local_wallet; - use xmtp_id::InboxOwner; - /// A macro for asserting that a stream yields a specific decrypted message. /// /// # Example @@ -602,282 +474,4 @@ pub(crate) mod tests { .is_empty()); }; } - - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] - async fn test_stream_all_messages_unchanging_group_list() { - let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; - - let alix_group = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - - let bo_group = bo - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - bo_group - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - - let stream = caro.stream_all_messages(None).await.unwrap(); - futures::pin_mut!(stream); - bo_group.send_message(b"first").await.unwrap(); - assert_msg!(stream, "first"); - - bo_group.send_message(b"second").await.unwrap(); - assert_msg!(stream, "second"); - - alix_group.send_message(b"third").await.unwrap(); - assert_msg!(stream, "third"); - - bo_group.send_message(b"fourth").await.unwrap(); - assert_msg!(stream, "fourth"); - } - - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] - async fn test_stream_all_messages_changing_group_list() { - let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; - let caro_wallet = generate_local_wallet(); - let caro = Arc::new(ClientBuilder::new_test_client(&caro_wallet).await); - - let alix_group = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - - let stream = caro.stream_all_messages(None).await.unwrap(); - futures::pin_mut!(stream); - tracing::info!("\n\nSENDING FIRST MESSAGE\n\n"); - - alix_group.send_message(b"first").await.unwrap(); - assert_msg!(stream, "first"); - - let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap(); - assert_msg_exists!(stream); - - bo_group.send_message(b"second").await.unwrap(); - assert_msg!(stream, "second"); - - alix_group.send_message(b"third").await.unwrap(); - assert_msg!(stream, "third"); - - let alix_group_2 = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group_2 - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - - alix_group.send_message(b"fourth").await.unwrap(); - assert_msg!(stream, "fourth"); - - alix_group_2.send_message(b"fifth").await.unwrap(); - assert_msg!(stream, "fifth"); - } - - #[ignore] - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] - async fn test_stream_all_messages_does_not_lose_messages() { - let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - let caro = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - - let alix_group = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); - - let blocked = Arc::new(AtomicU64::new(55)); - - let blocked_pointer = blocked.clone(); - let mut handle = Client::::stream_all_messages_with_callback( - caro.clone(), - None, - move |message| { - (*messages_clone.lock()).push(message.unwrap()); - blocked_pointer.fetch_sub(1, Ordering::SeqCst); - }, - ); - handle.wait_for_ready().await; - - let alix_group_pointer = alix_group.clone(); - crate::spawn(None, async move { - for _ in 0..50 { - alix_group_pointer.send_message(b"spam").await.unwrap(); - xmtp_common::time::sleep(core::time::Duration::from_micros(200)).await; - } - }); - - for _ in 0..5 { - let new_group = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - new_group - .add_members_by_inbox_id(&[caro.inbox_id()]) - .await - .unwrap(); - new_group - .send_message(b"spam from new group") - .await - .unwrap(); - } - - let _ = tokio::time::timeout(core::time::Duration::from_secs(120), async { - while blocked.load(Ordering::SeqCst) > 0 { - tokio::task::yield_now().await; - } - }) - .await; - - let missed_messages = blocked.load(Ordering::SeqCst); - if missed_messages > 0 { - println!("Missed {} Messages", missed_messages); - panic!("Test failed due to missed messages"); - } - } - - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] - async fn test_dm_stream_all_messages() { - let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - - let alix_group = alix - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - alix_group - .add_members_by_inbox_id(&[bo.inbox_id()]) - .await - .unwrap(); - - let alix_dm = alix - .create_dm_by_inbox_id(bo.inbox_id().to_string()) - .await - .unwrap(); - - // Start a stream with only groups - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - // Wait for 2 seconds for the group creation to be streamed - let notify = Delivery::new(Some(1)); - let (notify_pointer, messages_pointer) = (notify.clone(), messages.clone()); - - let mut closer = Client::::stream_all_messages_with_callback( - bo.clone(), - Some(ConversationType::Group), - move |message| { - let mut messages: parking_lot::lock_api::MutexGuard< - '_, - parking_lot::RawMutex, - Vec, - > = messages_pointer.lock(); - messages.push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - closer.wait_for_ready().await; - - alix_dm.send_message("first".as_bytes()).await.unwrap(); - - let result = notify.wait_for_delivery().await; - assert!(result.is_err(), "Stream unexpectedly received a DM message"); - - alix_group.send_message("second".as_bytes()).await.unwrap(); - - notify.wait_for_delivery().await.unwrap(); - { - let msgs = messages.lock(); - assert_eq!(msgs.len(), 1); - } - - closer.end(); - - // Start a stream with only dms - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - // Wait for 2 seconds for the group creation to be streamed - let notify = Delivery::new(Some(1)); - let (notify_pointer, messages_pointer) = (notify.clone(), messages.clone()); - - let mut closer = Client::::stream_all_messages_with_callback( - bo.clone(), - Some(ConversationType::Dm), - move |message| { - let mut messages: parking_lot::lock_api::MutexGuard< - '_, - parking_lot::RawMutex, - Vec, - > = messages_pointer.lock(); - messages.push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - closer.wait_for_ready().await; - - alix_group.send_message("first".as_bytes()).await.unwrap(); - - let result = notify.wait_for_delivery().await; - assert!( - result.is_err(), - "Stream unexpectedly received a Group message" - ); - - alix_dm.send_message("second".as_bytes()).await.unwrap(); - - notify.wait_for_delivery().await.unwrap(); - { - let msgs = messages.lock(); - assert_eq!(msgs.len(), 1); - } - - closer.end(); - - // Start a stream with all conversations - let messages: Arc>> = Arc::new(Mutex::new(Vec::new())); - // Wait for 2 seconds for the group creation to be streamed - let notify = Delivery::new(Some(1)); - let (notify_pointer, messages_pointer) = (notify.clone(), messages.clone()); - - let mut closer = Client::::stream_all_messages_with_callback( - bo.clone(), - None, - move |message| { - let mut messages = messages_pointer.lock(); - messages.push(message.unwrap()); - notify_pointer.notify_one(); - }, - ); - closer.wait_for_ready().await; - - alix_group.send_message("first".as_bytes()).await.unwrap(); - - notify.wait_for_delivery().await.unwrap(); - { - let msgs = messages.lock(); - assert_eq!(msgs.len(), 1); - } - - alix_dm.send_message("second".as_bytes()).await.unwrap(); - - notify.wait_for_delivery().await.unwrap(); - { - let msgs = messages.lock(); - assert_eq!(msgs.len(), 2); - } - - closer.end(); - } } diff --git a/xmtp_mls/src/subscriptions/stream_all.rs b/xmtp_mls/src/subscriptions/stream_all.rs index a14d45434..6953f7f42 100644 --- a/xmtp_mls/src/subscriptions/stream_all.rs +++ b/xmtp_mls/src/subscriptions/stream_all.rs @@ -1,48 +1,45 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + pin::Pin, + task::{Context, Poll}, +}; +use crate::subscriptions::stream_messages::{MessagePositionCursor, MessagesApiSubscription}; use crate::{ - client::ClientError, - groups::scoped_client::ScopedGroupClient, - groups::subscriptions, + groups::{scoped_client::ScopedGroupClient, MlsGroup}, storage::{ group::{ConversationType, GroupQueryArgs}, group_message::StoredGroupMessage, }, Client, }; -use futures::{ - stream::{self, Stream, StreamExt}, - Future, -}; +use futures::{stream::Stream, Future}; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams}; use super::{ stream_conversations::{StreamConversations, WelcomesApiSubscription}, - FutureWrapper, MessagesStreamInfo, SubscribeError, + stream_messages::StreamGroupMessages, + FutureWrapper, Result, SubscribeError, }; -pub struct StreamAllMessages<'a, C, Welcomes, Messages> { - /// The monolithic XMTP Client - client: &'a C, - /// Type of conversation to stream - conversation_type: Option, - /// Conversations that are being actively streamed - active_conversations: HashMap, MessagesStreamInfo>, - /// Welcomes Stream - welcomes: Welcomes, - /// Messages Stream - messages: Messages, - /// Extra messages from message stream, when the stream switches because - /// of a new group received. - extra_messages: Vec, +use pin_project_lite::pin_project; + +pin_project! { + pub(super) struct StreamAllMessages<'a, C, Conversations, Messages> { + #[pin] conversations: Conversations, + #[pin] messages: Messages, + #[pin] state: SwitchState<'a, Messages>, + client: &'a C, + conversation_type: Option, + } } -impl<'a, A, V, Messages> +impl<'a, A, V> StreamAllMessages< 'a, Client, - StreamConversations<'a, Client, WelcomesApiSubscription<'a, A>>, - FutureWrapper<'a, Result>, + StreamConversations<'a, Client, WelcomesApiSubscription<'a, Client>>, + StreamGroupMessages<'a, Client, MessagesApiSubscription<'a, Client>>, > where A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, @@ -51,8 +48,8 @@ where pub async fn new( client: &'a Client, conversation_type: Option, - ) -> Result { - let mut active_conversations = async { + ) -> Result { + let active_conversations = async { let provider = client.mls_provider()?; client.sync_welcomes(&provider).await?; @@ -61,15 +58,13 @@ where .find_groups(GroupQueryArgs::default().maybe_conversation_type(conversation_type))? .into_iter() .map(Into::into) - .collect::, MessagesStreamInfo>>(); - Ok::<_, ClientError>(active_conversations) + .collect::, MessagePositionCursor>>(); + Ok::<_, SubscribeError>(active_conversations) } .await?; - let messages = - subscriptions::stream_messages(client, Arc::new(active_conversations.clone())).await?; - let messages = FutureWrapper::new(messages); - let welcomes = super::stream_conversations::StreamConversations::new( + let messages = StreamGroupMessages::new(client, active_conversations).await?; + let conversations = super::stream_conversations::StreamConversations::new( client, conversation_type.clone(), ) @@ -79,23 +74,367 @@ where client, conversation_type, messages, - welcomes, - active_conversations, - extra_messages: Vec::new(), + conversations, + state: Default::default(), }) } } -impl<'a, C, Welcomes, Messages> Stream for StreamAllMessages<'a, C, Welcomes, Messages> +pin_project! { + #[project = SwitchProject] + #[derive(Default)] + enum SwitchState<'a, Out> { + /// State that indicates the stream is waiting on the next message from the network + #[default] + Waiting, + /// state that indicates the stream is waiting on a IO/Network future to finish processing + /// the current message before moving on to the next one + Switching { + #[pin] future: FutureWrapper<'a, Result> + } + } +} + +impl std::fmt::Debug for SwitchState<'_, Out> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SwitchState::Waiting => write!(f, "waiting"), + SwitchState::Switching { .. } => write!(f, "switching"), + } + } +} + +impl<'a, C, Conversations> Stream + for StreamAllMessages< + 'a, + C, + Conversations, + StreamGroupMessages<'a, C, MessagesApiSubscription<'a, C>>, + > +where + C: ScopedGroupClient + Clone + 'a, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'a, + Conversations: Stream>>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // tracing::debug!("POLLING STREAM ALL"); + use std::task::Poll::*; + use SwitchProject::*; + let this = self.as_mut().project(); + + let state = this.state.project(); + match state { + Waiting => { + if let Ready(msg) = this.messages.poll_next(cx) { + return Ready(msg); + } + if let Ready(Some(Ok(group))) = this.conversations.poll_next(cx) { + self.as_mut().begin_switch_stream(group); + tracing::trace!("stream all state = {:?}", self.state); + } + cx.waker().wake_by_ref(); + Pending + } + // switching message streams + Switching { future } => match future.poll(cx) { + Ready(Ok(stream)) => { + self.as_mut().end_switch_stream(stream, cx); + tracing::trace!("stream all state state = {:?}", self.state); + Pending + } + Ready(Err(e)) => { + tracing::error!("Error swapping message stream in StreamAllMessages {}", e); + Ready(Some(Err(e))) + } + Pending => { + cx.waker().wake_by_ref(); + Pending + } + }, + } + } +} + +impl<'a, C, Conversations> + StreamAllMessages< + 'a, + C, + Conversations, + StreamGroupMessages<'a, C, MessagesApiSubscription<'a, C>>, + > where - C: ScopedGroupClient, + C: ScopedGroupClient + Clone + 'a, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'a, + Conversations: Stream>>, { - type Item = Result; + /// Polls groups + /// if groups are available, the stream starts waiting for the future to switch message + /// streams. + fn begin_switch_stream(mut self: Pin<&mut Self>, new_group: MlsGroup) { + if self.messages.group_list().contains_key(&new_group.group_id) { + return; + } + + tracing::trace!( + inbox_id = self.client.inbox_id(), + installation_id = %self.client.installation_id(), + group_id = hex::encode(&new_group.group_id), + "begin establishing new message stream to include group_id={}", + hex::encode(&new_group.group_id) + ); + + let mut conversations = self.messages.group_list().clone(); + conversations.insert(new_group.group_id, 1.into()); + + let future = StreamGroupMessages::new(self.client, conversations); + let mut this = self.as_mut().project(); + this.state.set(SwitchState::Switching { + future: FutureWrapper::new(future), + }); + } + + fn end_switch_stream( + mut self: Pin<&mut Self>, + stream: StreamGroupMessages<'a, C, MessagesApiSubscription<'a, C>>, + cx: &mut Context<'_>, + ) { + let mut this = self.as_mut().project(); + // drain the stream + // if we don't drain the stream, we inadvertantly create a zombie stream + // that freezes the executor + // Not entirely certain why it happens, but i assume gRPC does not like closing the stream + // because we have unread items in queue. + // We can throw away the drained messages, because we set the cursor for the stream + // before these messages were received + this.messages.as_mut().drain(cx); + this.messages.set(stream); + this.state.as_mut().set(SwitchState::Waiting); + // TODO: take old group list and .diff with new, to check which group is new + // for log msg. + tracing::trace!("established new stream"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); + + use std::sync::Arc; + + use crate::{assert_msg, builder::ClientBuilder, groups::GroupMetadataOptions}; + use xmtp_cryptography::utils::generate_local_wallet; + use xmtp_id::InboxOwner; + + use futures::StreamExt; + use wasm_bindgen_test::wasm_bindgen_test; + + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + async fn test_stream_all_messages_changing_group_list() { + let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let caro_wallet = generate_local_wallet(); + let caro = ClientBuilder::new_test_client(&caro_wallet).await; + + let alix_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group + .add_members_by_inbox_id(&[caro.inbox_id()]) + .await + .unwrap(); + + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + + alix_group.send_message(b"first").await.unwrap(); + assert_msg!(stream, "first"); + let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap(); + + bo_group.send_message(b"second").await.unwrap(); + assert_msg!(stream, "second"); + + alix_group.send_message(b"third").await.unwrap(); + assert_msg!(stream, "third"); + + let alix_group_2 = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group_2 + .add_members_by_inbox_id(&[caro.inbox_id()]) + .await + .unwrap(); + + alix_group.send_message(b"fourth").await.unwrap(); + assert_msg!(stream, "fourth"); + + alix_group_2.send_message(b"fifth").await.unwrap(); + assert_msg!(stream, "fifth"); + } + + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + async fn test_stream_all_messages_unchanging_group_list() { + let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let alix_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group + .add_members_by_inbox_id(&[caro.inbox_id()]) + .await + .unwrap(); + + let bo_group = bo + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + bo_group + .add_members_by_inbox_id(&[caro.inbox_id()]) + .await + .unwrap(); + + let stream = caro.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + bo_group.send_message(b"first").await.unwrap(); + assert_msg!(stream, "first"); + + bo_group.send_message(b"second").await.unwrap(); + assert_msg!(stream, "second"); + + alix_group.send_message(b"third").await.unwrap(); + assert_msg!(stream, "third"); + + bo_group.send_message(b"fourth").await.unwrap(); + assert_msg!(stream, "fourth"); + } + + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] + async fn test_dm_stream_all_messages() { + let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let alix_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group + .add_members_by_inbox_id(&[bo.inbox_id()]) + .await + .unwrap(); + + let alix_dm = alix + .create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) + .await + .unwrap(); + + // start a stream with only group messages + let stream = bo + .stream_all_messages(Some(ConversationType::Group)) + .await + .unwrap(); + futures::pin_mut!(stream); + alix_dm.send_message("first".as_bytes()).await.unwrap(); + alix_group.send_message("second".as_bytes()).await.unwrap(); + assert_msg!(stream, "second"); + + // Start a stream with only dms + // Wait for 2 seconds for the group creation to be streamed + let stream = bo + .stream_all_messages(Some(ConversationType::Dm)) + .await + .unwrap(); + futures::pin_mut!(stream); + alix_group.send_message("first".as_bytes()).await.unwrap(); + alix_dm.send_message("second".as_bytes()).await.unwrap(); + assert_msg!(stream, "second"); + + // Start a stream with all conversations + // Wait for 2 seconds for the group creation to be streamed + let stream = bo.stream_all_messages(None).await.unwrap(); + futures::pin_mut!(stream); + alix_group.send_message("first".as_bytes()).await.unwrap(); + assert_msg!(stream, "first"); + + alix_dm.send_message("second".as_bytes()).await.unwrap(); + assert_msg!(stream, "second"); + } + + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + async fn test_stream_all_messages_does_not_lose_messages() { + let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + let eve = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); + tracing::info!(inbox_id = eve.inbox_id(), installation_id = %eve.installation_id(), "EVE"); + + let alix_group = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_group + .add_members_by_inbox_id(&[caro.inbox_id()]) + .await + .unwrap(); + + let stream = caro.stream_all_messages(None).await.unwrap(); + + let alix_group_pointer = alix_group.clone(); + crate::spawn(None, async move { + let mut sent = 0; + for _ in 0..50 { + alix_group_pointer.send_message(b"spam").await.unwrap(); + sent += 1; + xmtp_common::time::sleep(core::time::Duration::from_micros(100)).await; + tracing::info!("sent {sent}"); + } + }); + + // Eve will try to break our stream by creating lots of groups + // and immediately sending a message + // this forces our streams to re-subscribe + let caro_id = caro.inbox_id().to_string(); + crate::spawn(None, async move { + let caro = &caro_id; + for i in 0..50 { + let new_group = eve + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + new_group.add_members_by_inbox_id(&[caro]).await.unwrap(); + tracing::info!("\n\n EVE SENDING {i} \n\n"); + new_group + .send_message(b"spam from new group") + .await + .unwrap(); + } + }); + + let mut messages = Vec::new(); + let _ = tokio::time::timeout(core::time::Duration::from_secs(60), async { + futures::pin_mut!(stream); + loop { + if messages.len() < 100 { + if let Some(Ok(msg)) = stream.next().await { + tracing::info!( + message_id = hex::encode(&msg.id), + sender_inbox_id = msg.sender_inbox_id, + sender_installation_id = hex::encode(&msg.sender_installation_id), + group_id = hex::encode(&msg.group_id), + "GOT MESSAGE {}, text={}", + messages.len(), + String::from_utf8_lossy(msg.decrypted_message_bytes.as_slice()) + ); + messages.push(msg) + } + } else { + break; + } + } + }) + .await; - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - todo!() + tracing::info!("Total Messages: {}", messages.len()); + assert_eq!(messages.len(), 100); } } diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs index 64eca085e..f9ef430b0 100644 --- a/xmtp_mls/src/subscriptions/stream_conversations.rs +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -19,7 +19,7 @@ use xmtp_proto::{ xmtp::mls::api::v1::{welcome_message, WelcomeMessage}, }; -use super::{temp::Result, FutureWrapper, LocalEvents, SubscribeError}; +use super::{FutureWrapper, LocalEvents, Result, SubscribeError}; #[derive(thiserror::Error, Debug)] pub enum ConversationStreamError { @@ -63,7 +63,6 @@ impl Stream for BroadcastGroupStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use std::task::Poll::*; let this = self.project(); - match this.inner.poll_next(cx) { Ready(Some(event)) => { let ev = xmtp_common::optify!(event, "Missed messages due to event queue lag") @@ -158,10 +157,11 @@ impl<'a, O> Default for ProcessState<'a, O> { type MultiplexedSelect = Select>; -pub(super) type WelcomesApiSubscription<'a, A> = - MultiplexedSelect<::WelcomeMessageStream<'a>>; +pub(super) type WelcomesApiSubscription<'a, C> = MultiplexedSelect< + <::ApiClient as XmtpMlsStreams>::WelcomeMessageStream<'a>, +>; -impl<'a, A, V> StreamConversations<'a, Client, WelcomesApiSubscription<'a, A>> +impl<'a, A, V> StreamConversations<'a, Client, WelcomesApiSubscription<'a, Client>> where A: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, V: SmartContractSignatureVerifier + Send + Sync + 'static, @@ -214,7 +214,6 @@ where ) -> std::task::Poll> { use std::task::Poll::*; let mut this = self.as_mut().project(); - match this.state.as_mut().project() { ProcessProject::Waiting => { match this.inner.poll_next(cx) { @@ -255,7 +254,10 @@ where Pending } Ready(Err(e)) => Ready(Some(Err(e))), - Pending => Pending, + Pending => { + cx.waker().wake_by_ref(); + Pending + } }, } } @@ -270,10 +272,15 @@ fn extract_welcome_message<'a>(welcome: &'a WelcomeMessage) -> Result<&'a welcom /// Future for processing `WelcomeorGroup` pub struct ProcessWelcomeFuture { + /// welcome ids in DB and which are already processed known_welcome_ids: HashSet, + /// The libxmtp client client: Client, + /// the welcome or group being processed in this future item: WelcomeOrGroup, + /// the xmtp mls provider provider: XmtpOpenMlsProvider, + /// Conversation type to filter for, if any. conversation_type: Option, } @@ -371,7 +378,9 @@ where if let Err(e) = group { // try to load it from the store again - return self.load_from_store(id).map_err(|_| SubscribeError::from(e)); + return self + .load_from_store(id) + .map_err(|_| SubscribeError::from(e)); } Ok((group?, id)) @@ -404,7 +413,7 @@ mod test { use super::*; use crate::builder::ClientBuilder; use crate::groups::GroupMetadataOptions; - use crate::subscriptions::GroupQueryArgs; + use crate::storage::group::GroupQueryArgs; use futures::StreamExt; use wasm_bindgen_test::wasm_bindgen_test; diff --git a/xmtp_mls/src/subscriptions/stream_messages.rs b/xmtp_mls/src/subscriptions/stream_messages.rs index 433b39c91..5d38a6ed0 100644 --- a/xmtp_mls/src/subscriptions/stream_messages.rs +++ b/xmtp_mls/src/subscriptions/stream_messages.rs @@ -1,10 +1,18 @@ -use std::{collections::HashMap, future::Future, pin::Pin, task::Poll}; +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; -use super::{temp::Result, FutureWrapper, SubscribeError}; +use super::{FutureWrapper, Result, SubscribeError}; use crate::{ api::GroupFilter, groups::{scoped_client::ScopedGroupClient, MlsGroup}, - storage::{group_message::StoredGroupMessage, refresh_state::EntityKind, StorageError}, + storage::{ + group::StoredGroup, group_message::StoredGroupMessage, refresh_state::EntityKind, + StorageError, + }, XmtpOpenMlsProvider, }; use futures::Stream; @@ -40,20 +48,56 @@ fn extract_message_v1(message: GroupMessage) -> Result { } } -type GroupId = Vec; +pub(super) type GroupId = Vec; + +/// the position of this message in the backend topic +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MessagePositionCursor(u64); + +impl MessagePositionCursor { + pub(super) fn set(&mut self, cursor: u64) { + self.0 = cursor; + } +} + +impl std::fmt::Display for MessagePositionCursor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for (Vec, u64) { + fn from(group: StoredGroup) -> (Vec, u64) { + (group.id, 0u64) + } +} + +impl From for (Vec, MessagePositionCursor) { + fn from(group: StoredGroup) -> (Vec, MessagePositionCursor) { + (group.id, 0u64.into()) + } +} -/// Information about the current position -/// in a stream of messages from a single group. -#[derive(Debug)] -pub(crate) struct MessagesStreamInfo { - pub cursor: u64, +impl std::ops::Deref for MessagePositionCursor { + type Target = u64; + + fn deref(&self) -> &u64 { + &self.0 + } +} + +impl From for MessagePositionCursor { + fn from(v: u64) -> MessagePositionCursor { + Self(v) + } } pin_project! { pub struct StreamGroupMessages<'a, C, Subscription> { #[pin] inner: Subscription, + #[pin] state: ProcessState<'a>, client: &'a C, - #[pin] state: ProcessState<'a> + group_list: HashMap, } } @@ -67,7 +111,7 @@ pin_project! { /// state that indicates the stream is waiting on a IO/Network future to finish processing /// the current message before moving on to the next one Processing { - #[pin] future: FutureWrapper<'a, Result> + #[pin] future: FutureWrapper<'a, Result<(StoredGroupMessage, u64)>> } } } @@ -82,11 +126,11 @@ where { pub async fn new( client: &'a C, - group_list: &HashMap, + group_list: HashMap, ) -> Result { let filters: Vec = group_list .iter() - .map(|(group_id, info)| GroupFilter::new(group_id.clone(), Some(info.cursor))) + .map(|(group_id, cursor)| GroupFilter::new(group_id.clone(), Some(**cursor))) .collect(); let subscription = client.api().subscribe_group_messages(filters).await?; @@ -94,6 +138,7 @@ where inner: subscription, client, state: Default::default(), + group_list: group_list.into_iter().map(|(g, c)| (g, c.into())).collect(), }) } } @@ -105,10 +150,8 @@ where { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // tracing::debug!("POLLING STREAM MESSAGES"); use std::task::Poll::*; use ProcessProject::*; let mut this = self.as_mut().project(); @@ -116,6 +159,7 @@ where match this.state.as_mut().project() { Waiting => match this.inner.poll_next(cx) { Ready(Some(envelope)) => { + tracing::debug!("processing message in stream"); let future = ProcessMessageFuture::new(*this.client, envelope?)?; let future = future.process(); this.state.set(ProcessState::Processing { @@ -131,8 +175,14 @@ where Ready(None) => Ready(None), }, Processing { future } => match future.poll(cx) { - Ready(Ok(msg)) => { + Ready(Ok((msg, new_cursor))) => { this.state.set(ProcessState::Waiting); + if let Some(tracked_cursor) = this.group_list.get_mut(&msg.group_id) { + tracked_cursor.set(new_cursor) + } else { + this.group_list + .insert(msg.group_id.clone(), new_cursor.into()); + } Ready(Some(Ok(msg))) } // skip if payload GroupMessageNotFound @@ -152,6 +202,29 @@ where } } +impl<'a, C, S> StreamGroupMessages<'a, C, S> { + pub(super) fn group_list(&self) -> &HashMap { + &self.group_list + } +} + +impl<'a, C, S> StreamGroupMessages<'a, C, S> +where + S: Stream> + 'a, + C: ScopedGroupClient + 'a, +{ + pub(super) fn drain( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Vec>> { + let mut drained = Vec::new(); + while let Poll::Ready(msg) = self.as_mut().poll_next(cx) { + drained.push(msg); + } + drained + } +} + /// Future that processes a group message from the network pub struct ProcessMessageFuture { provider: XmtpOpenMlsProvider, @@ -184,9 +257,11 @@ where self.client.inbox_id() } - pub(crate) async fn process(self) -> Result { + /// process a message, returning the message from the database and the cursor of the message. + pub(crate) async fn process(self) -> Result<(StoredGroupMessage, u64)> { let group_message::V1 { - id: ref msg_id, + // the cursor ID is the position in the monolithic backend topic + id: ref cursor_id, ref created_ns, .. } = self.msg; @@ -194,13 +269,13 @@ where tracing::info!( inbox_id = self.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id, + cursor_id, "client [{}] is about to process streamed envelope: [{}]", self.inbox_id(), - &msg_id + &cursor_id ); - if !self.has_already_synced(*msg_id).await? { + if self.needs_to_sync(*cursor_id).await? { self.process_stream_entry().await } @@ -214,14 +289,14 @@ where .inspect_err(|e| { if matches!(e, SubscribeError::GroupMessageNotFound) { tracing::warn!( - msg_id, + cursor_id, inbox_id = self.inbox_id(), group_id = hex::encode(&self.msg.group_id), "group message not found" ); } })?; - return Ok(new_message); + return Ok((new_message, *cursor_id)); } /// stream processing function @@ -235,7 +310,7 @@ where tracing::info!( inbox_id = self.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id = self.msg.id, + cursor_id = self.msg.id, "current epoch for [{}] in process_stream_entry()", self.inbox_id(), ); @@ -255,14 +330,14 @@ where tracing::error!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id = self.msg.id, + cursor_id = self.msg.id, err = e.to_string(), "process stream entry {:?}", e ); } else { tracing::trace!( - msg_id = self.msg.id, + cursor_id = self.msg.id, inbox_id = self.inbox_id(), group_id = hex::encode(&self.msg.group_id), "message process in stream success" @@ -270,16 +345,16 @@ where } } - // Checks if a message has already been processed through a sync - async fn has_already_synced(&self, id: u64) -> Result { + /// Checks if a message has already been processed through a sync + async fn needs_to_sync(&self, current_msg_cursor: u64) -> Result { let check_for_last_cursor = || -> std::result::Result { self.provider .conn_ref() .get_last_cursor_for_id(&self.msg.group_id, EntityKind::Group) }; - let last_id = retry_async!(Retry::default(), (async { check_for_last_cursor() }))?; - Ok(last_id >= id as i64) + let last_synced_id = retry_async!(Retry::default(), (async { check_for_last_cursor() }))?; + Ok(last_synced_id < current_msg_cursor as i64) } /// Attempt a recovery sync if a group message failed to process @@ -292,7 +367,7 @@ where tracing::debug!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id = self.msg.id, + cursor_id = self.msg.id, "attempting recovery sync" ); // Swallow errors here, since another process may have successfully saved the message @@ -301,7 +376,7 @@ where tracing::warn!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id = self.msg.id, + cursor_id = self.msg.id, err = %err, "recovery sync triggered by streamed message failed: {}", err ); @@ -309,7 +384,7 @@ where tracing::debug!( inbox_id = self.client.inbox_id(), group_id = hex::encode(&self.msg.group_id), - msg_id = self.msg.id, + cursor_id = self.msg.id, "recovery sync triggered by streamed message successful" ) } @@ -323,6 +398,7 @@ mod tests { use futures::stream::StreamExt; use wasm_bindgen_test::wasm_bindgen_test; + use crate::assert_msg; use crate::{builder::ClientBuilder, groups::GroupMetadataOptions}; use xmtp_cryptography::utils::generate_local_wallet; @@ -354,11 +430,9 @@ mod tests { // implicitly skips the first message (add bob to group message) // since that is an epoch increment. - let message = stream.next().await.unwrap().unwrap(); - assert_eq!(message.decrypted_message_bytes, b"hello"); + assert_msg!(stream, "hello"); bob_group.send_message(b"hello2").await.unwrap(); - let message = stream.next().await.unwrap().unwrap(); - assert_eq!(message.decrypted_message_bytes, b"hello2"); + assert_msg!(stream, "hello2"); } }