diff --git a/Cargo.lock b/Cargo.lock index 83826030c..99009e7e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -327,7 +327,7 @@ dependencies = [ "rustversion", "serde", "sync_wrapper 1.0.2", - "tower 0.5.1", + "tower 0.5.2", "tower-layer", "tower-service", ] @@ -4817,7 +4817,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.9" +version = "0.12.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" dependencies = [ "base64 0.22.1", "bytes", @@ -4849,8 +4851,8 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-util", + "tower 0.5.2", "tower-service", - "tracing", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -6161,14 +6163,15 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.2", + "tokio", "tower-layer", "tower-service", ] @@ -6906,10 +6909,11 @@ dependencies = [ [[package]] name = "wasm-streams" version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", - "tracing", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", @@ -7333,7 +7337,7 @@ dependencies = [ "bytes", "futures", "pin-project-lite", - "reqwest 0.12.9", + "reqwest 0.12.12", "serde", "serde_json", "thiserror 2.0.6", @@ -7513,7 +7517,7 @@ dependencies = [ "pin-project-lite", "prost", "rand", - "reqwest 0.12.9", + "reqwest 0.12.12", "serde", "serde_json", "sha2 0.10.8", diff --git a/Cargo.toml b/Cargo.toml index 2939b6399..69b6c8acc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,6 +158,4 @@ diesel = { git = "https://github.com/diesel-rs/diesel", branch = "master" } diesel_derives = { git = "https://github.com/diesel-rs/diesel", branch = "master" } diesel_migrations = { git = "https://github.com/diesel-rs/diesel", branch = "master" } sqlite-web = { git = "https://github.com/xmtp/sqlite-web-rs", branch = "main" } -reqwest = { path = "/Users/insipx/Projects/seanmonster/reqwest" } -wasm-streams = { path = "/Users/insipx/Projects/wasm-streams" } diff --git a/common/src/test.rs b/common/src/test.rs index 802f01fbc..215e682f3 100644 --- a/common/src/test.rs +++ b/common/src/test.rs @@ -73,9 +73,9 @@ pub fn logger() { use tracing_subscriber::EnvFilter; INIT.get_or_init(|| { - let filter = EnvFilter::builder() - .parse_lossy("xmtp_mls::subscriptions=TRACE,xmtp_api_http=TRACE,xmtp_common=TRACE,wasm_streams=TRACE,reqwest=TRACE"); - // .with_default_directive(tracing::metadata::LevelFilter::DEBUG.into()) + let filter = EnvFilter::builder().parse_lossy("xmtp_mls::subscriptions=debug"); + // .parse_lossy("xmtp_mls::subscriptions=TRACE,xmtp_api_http=TRACE,xmtp_common=TRACE,wasm_streams=TRACE,reqwest=TRACE"); + // .with_default_directive(tracing::metadata::LevelFilter::TRACE.into()); tracing_subscriber::registry() .with(tracing_wasm::WASMLayer::default()) diff --git a/common/src/wasm.rs b/common/src/wasm.rs index 8c2724bc6..4c02b3554 100644 --- a/common/src/wasm.rs +++ b/common/src/wasm.rs @@ -1,5 +1,5 @@ -use std::{pin::Pin, task::Poll, future::Future}; -use futures::{Stream, FutureExt, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; +use std::{future::Future, pin::Pin, task::Poll}; #[cfg(target_arch = "wasm32")] use wasm_bindgen::prelude::*; @@ -23,7 +23,10 @@ pub struct StreamWrapper<'a, I> { impl<'a, I> Stream for StreamWrapper<'a, I> { type Item = I; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { let inner = &mut self.inner; futures::pin_mut!(inner); inner.as_mut().poll_next(cx) @@ -106,19 +109,3 @@ pub async fn yield_() { pub async fn yield_() { crate::time::sleep(crate::time::Duration::from_millis(100)).await; } - -#[cfg(target_arch = "wasm32")] -mod inner { - use super::*; - - #[wasm_bindgen] - extern "C" { - #[wasm_bindgen (extends = js_sys::Object, js_name = Scheduler, typescript_type = "Scheduler")] - pub type Scheduler; - - #[wasm_bindgen(method, structural, js_class = "Scheduler", js_name = yield)] - pub fn r#yield(this: &Scheduler) -> js_sys::Promise; - } -} -#[cfg(target_arch = "wasm32")] -use inner::*; diff --git a/dev/test-wasm-interactive b/dev/test-wasm-interactive index 81980bdda..51e1535bf 100755 --- a/dev/test-wasm-interactive +++ b/dev/test-wasm-interactive @@ -15,4 +15,4 @@ WASM_BINDGEN_SPLIT_LINKED_MODULES=1 \ WASM_BINDGEN_TEST_ONLY_WEB=1 \ NO_HEADLESS=1 \ cargo test --target wasm32-unknown-unknown --release \ - -p $PACKAGE -- subscriptions::stream_conversations::test::test_stream_welcomes + -p $PACKAGE -- subscriptions:: diff --git a/xmtp_api_http/src/http_stream.rs b/xmtp_api_http/src/http_stream.rs index ff3d77b8f..cf18c1027 100644 --- a/xmtp_api_http/src/http_stream.rs +++ b/xmtp_api_http/src/http_stream.rs @@ -12,7 +12,7 @@ use serde_json::Deserializer; use std::{ marker::PhantomData, pin::Pin, - task::{Context, Poll, ready}, + task::{ready, Context, Poll}, }; use xmtp_common::StreamWrapper; use xmtp_proto::{Error, ErrorKind}; @@ -49,15 +49,11 @@ where use Poll::*; let this = self.as_mut().project(); let response = ready!(this.inner.poll(cx)); - tracing::info!("ESTABLISH READY"); let stream = response - .inspect_err(|e| { - tracing::error!( - "Error during http subscription with grpc http gateway {e}" - ); - }) - .map_err(|_| Error::new(ErrorKind::SubscribeError))?; - tracing::info!("Calling bytes stream!"); + .inspect_err(|e| { + tracing::error!("Error during http subscription with grpc http gateway {e}"); + }) + .map_err(|_| Error::new(ErrorKind::SubscribeError))?; Ready(Ok(StreamWrapper::new(stream.bytes_stream()))) } } @@ -89,13 +85,13 @@ where .inspect_err(|e| tracing::error!("Error in http stream to grpc gateway {e}")) .map_err(|_| Error::new(ErrorKind::SubscribeError))?; let item = Self::on_bytes(bytes, this.remaining)?.pop(); - if let None = item { + if item.is_none() { self.poll_next(cx) } else { Ready(Some(Ok(item.expect("handled none;")))) } - }, - None => Ready(None) + } + None => Ready(None), } } } @@ -119,7 +115,6 @@ where for<'de> R: Deserialize<'de> + DeserializeOwned + Send, { fn on_bytes(bytes: bytes::Bytes, remaining: &mut Vec<u8>) -> Result<Vec<R>, Error> { - tracing::info!("BYTES: {:x}", bytes); let bytes = &[remaining.as_ref(), bytes.as_ref()].concat(); let de = Deserializer::from_slice(bytes); let mut deser_stream = de.into_iter::<GrpcResponse<R>>(); @@ -138,7 +133,6 @@ where Err(e) => { if e.is_eof() { *remaining = (&**bytes)[deser_stream.byte_offset()..].to_vec(); - tracing::debug!("IS EOF"); break; } else { return Err(Error::new(ErrorKind::MlsError).with(e.to_string())); @@ -147,6 +141,10 @@ where Ok(GrpcResponse::Empty {}) => continue, } } + + if items.len() > 1 { + tracing::warn!("more than one item deserialized from http stream"); + } Ok(items) } } @@ -179,7 +177,9 @@ where let id = xmtp_common::rand_string::<12>(); tracing::info!("new http stream id={}", &id); Self { - state: HttpStreamState::NotStarted { future: HttpStreamEstablish::new(request) }, + state: HttpStreamState::NotStarted { + future: HttpStreamEstablish::new(request), + }, id, } } @@ -197,22 +197,21 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Option<Self::Item>> { use ProjectHttpStream::*; - tracing::info!("Polling http stream id={}", &self.id); + tracing::trace!("Polling http stream id={}", &self.id); let mut this = self.as_mut().project(); match this.state.as_mut().project() { NotStarted { future } => { let stream = ready!(future.poll(cx))?; - tracing::info!("Ready TOP LEVEL"); - this.state.set(HttpStreamState::Started { stream: HttpPostStream::new(stream)}); - tracing::info!("Stream {} ready, polling for the first time...", &self.id); + this.state.set(HttpStreamState::Started { + stream: HttpPostStream::new(stream), + }); + tracing::debug!("Stream {} ready, polling for the first time...", &self.id); self.poll_next(cx) - }, + } Started { stream } => { - let res = stream.poll_next(cx); - if let Poll::Ready(_) = res { - tracing::info!("stream id={} ready with item", &self.id); - } - res + let item = ready!(stream.poll_next(cx)); + tracing::debug!("stream id={} ready with item", &self.id); + Poll::Ready(item) } } } @@ -221,8 +220,8 @@ where impl<'a, F, R> std::fmt::Debug for HttpStream<'a, F, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { match self.state { - HttpStreamState::NotStarted{..} => write!(f, "not started"), - HttpStreamState::Started{..} => write!(f, "started"), + HttpStreamState::NotStarted { .. } => write!(f, "not started"), + HttpStreamState::Started { .. } => write!(f, "started"), } } } diff --git a/xmtp_mls/' b/xmtp_mls/' deleted file mode 100644 index b088ca7b4..000000000 --- a/xmtp_mls/' +++ /dev/null @@ -1,516 +0,0 @@ -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - task::{Context, Poll, ready}, -}; - -use super::{Result, SubscribeError}; -use crate::{ - api::GroupFilter, - groups::{scoped_client::ScopedGroupClient, MlsGroup}, - types::GroupId, - storage::{ - group::StoredGroup, group_message::StoredGroupMessage, refresh_state::EntityKind, - StorageError, - }, - XmtpOpenMlsProvider, -}; -use futures::{Stream, TryFutureExt}; -use pin_project_lite::pin_project; -use xmtp_common::FutureWrapper; -use xmtp_common::{retry_async, Retry}; -use xmtp_id::InboxIdRef; -use xmtp_proto::{ - api_client::{trait_impls::XmtpApi, XmtpMlsStreams}, - xmtp::mls::api::v1::{group_message, GroupMessage}, -}; - -#[derive(thiserror::Error, Debug)] -pub enum MessageStreamError { - #[error("received message for not subscribed group {id}", id = hex::encode(_0))] - NotSubscribed(Vec<u8>), - #[error("Invalid Payload")] - InvalidPayload, -} - -impl xmtp_common::RetryableError for MessageStreamError { - fn is_retryable(&self) -> bool { - use MessageStreamError::*; - match self { - NotSubscribed(_) | InvalidPayload => false, - } - } -} - -fn extract_message_v1(message: GroupMessage) -> Result<group_message::V1> { - match message.version { - Some(group_message::Version::V1(value)) => Ok(value), - _ => Err(MessageStreamError::InvalidPayload.into()), - } -} - - -/// the position of this message in the backend topic -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct MessagePositionCursor(u64); - -impl MessagePositionCursor { - pub(super) fn set(&mut self, cursor: u64) { - self.0 = cursor; - } -} - -impl std::fmt::Display for MessagePositionCursor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<StoredGroup> for (Vec<u8>, u64) { - fn from(group: StoredGroup) -> (Vec<u8>, u64) { - (group.id, 0u64) - } -} - -impl From<StoredGroup> for (Vec<u8>, MessagePositionCursor) { - fn from(group: StoredGroup) -> (Vec<u8>, MessagePositionCursor) { - (group.id, 0u64.into()) - } -} - -impl std::ops::Deref for MessagePositionCursor { - type Target = u64; - - fn deref(&self) -> &u64 { - &self.0 - } -} - -impl From<u64> for MessagePositionCursor { - fn from(v: u64) -> MessagePositionCursor { - Self(v) - } -} - -pin_project! { - pub struct StreamGroupMessages<'a, C, Subscription> { - #[pin] inner: Subscription, - #[pin] state: State<'a, Subscription>, - client: &'a C, - group_list: HashMap<GroupId, MessagePositionCursor>, - } -} - -pin_project! { - #[project = ProjectState] - #[derive(Default)] - enum State<'a, Out> { - /// State that indicates the stream is waiting on the next message from the network - #[default] - Waiting, - /// state that indicates the stream is waiting on a IO/Network future to finish processing - /// the current message before moving on to the next one - Processing { - #[pin] future: FutureWrapper<'a, Result<Option<(StoredGroupMessage, u64)>>> - }, - Adding { - #[pin] future: FutureWrapper<'a, Result<Out>> - } - } -} - -pub(super) type MessagesApiSubscription<'a, C> = - <<C as ScopedGroupClient>::ApiClient as XmtpMlsStreams>::GroupMessageStream<'a>; - -impl<'a, C> StreamGroupMessages<'a, C, MessagesApiSubscription<'a, C>> -where - C: ScopedGroupClient + 'a, - <C as ScopedGroupClient>::ApiClient: XmtpApi + XmtpMlsStreams + 'a, -{ - pub async fn new( - client: &'a C, - group_list: Vec<GroupId>, - ) -> Result<Self> { - tracing::debug!("setting up messages subscription"); - let group_list = group_list.into_iter().map(|group_id| { - Ok((group_id, 1u64)) - }).collect::<Result<HashMap<GroupId, u64>>>()?; - let filters: Vec<GroupFilter> = group_list - .iter() - .map(|(group_id, cursor)| GroupFilter::new(group_id.to_vec(), Some(*cursor))) - .collect(); - for filter in &filters { - tracing::debug!("Subscribing to {} for group messages", hex::encode(&filter.group_id)); - } - let subscription = client.api().subscribe_group_messages(filters).await?; - - Ok(Self { - inner: subscription, - client, - state: Default::default(), - group_list: group_list.into_iter().map(|(g, c)| (g, c.into())).collect(), - }) - } - - /// Add a new group to this messages stream - pub(super) fn add(mut self: Pin<&mut Self>, group: MlsGroup<C>) { - tracing::info!("creating new messages stream to add group {}", hex::encode(&group.group_id)); - if self.group_list.contains_key(group.group_id.as_slice()) { - tracing::info!("group {} already in stream", hex::encode(&group.group_id)); - return; - } - - tracing::debug!( - inbox_id = self.client.inbox_id(), - installation_id = %self.client.installation_id(), - group_id = hex::encode(&group.group_id), - "begin establishing new message stream to include group_id={}", - hex::encode(&group.group_id) - ); - let this = self.as_mut().project(); - let mut filters = self.filters(); - // add the new group but not to our state. - // We will add the group to our state once we get the first message. - // In that message will be the real cursor, rather than a temporary `1` - filters.push(GroupFilter::new(group.group_id, Some(1))); - let future = self.client.api().subscribe_group_messages(filters).map_err(SubscribeError::from); - let mut this = self.as_mut().project(); - this.state.set(State::Adding { future: FutureWrapper::new(future)}); - } -} - -impl<'a, C, Subscription> Stream for StreamGroupMessages<'a, C, Subscription> -where - C: ScopedGroupClient + 'a, - Subscription: Stream<Item = std::result::Result<GroupMessage, xmtp_proto::Error>> + 'a, -{ - type Item = Result<StoredGroupMessage>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - // tracing::debug!("POLLING STREAM MESSAGES"); - use std::task::Poll::*; - use ProjectState::*; - let mut this = self.as_mut().project(); - - match this.state.as_mut().project() { - Waiting => { - if let Some(envelope) = ready!(this.inner.poll_next(cx)) { - tracing::debug!("processing message in stream"); - let future = ProcessMessageFuture::new(*this.client, envelope?)?; - let future = future.process(); - this.state.set(State::Processing { - future: FutureWrapper::new(future), - }); - self.try_update_state(cx) - } else { // the stream ended - Ready(None) - } - }, - Processing { .. } => self.try_update_state(cx), - Adding { future } => { - let stream = ready!(future.poll(cx))?; - let mut this = self.as_mut().project(); - this.inner.set(stream); - this.state.as_mut().set(State::Waiting); - tracing::debug!("added group to messages stream"); - self.poll_next(cx) - } - } - } -} - -impl<'a, C, S> StreamGroupMessages<'a, C, S> { - fn filters(&self) -> Vec<GroupFilter> { - self.group_list - .iter() - .map(|(group_id, cursor)| GroupFilter::new(group_id.to_vec(), Some(**cursor))) - .collect() - } -} - -impl<'a, C, Subscription> StreamGroupMessages<'a, C, Subscription> -where - C: ScopedGroupClient + 'a, - Subscription: Stream<Item = std::result::Result<GroupMessage, xmtp_proto::Error>> + 'a, -{ - /// Try to finish processing the stream item by polling the stored future. - /// Update state to `Waiting` and insert the new cursor if ready. - /// If Stream state is in `Waiting`, returns `Pending`. - fn try_update_state(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> { - use ProjectState::*; - - let mut this = self.as_mut().project(); - if let Processing { future } = this.state.as_mut().project() { - match ready!(future.poll(cx))? { - Some((msg, new_cursor)) => { - this.state.set(State::Waiting); - if let Some(tracked_cursor) = this.group_list.get_mut(msg.group_id.as_slice()) { - tracked_cursor.set(new_cursor); - return Poll::Ready(Some(Ok(msg))); - } else { - tracing::info!("\n\nGot new group\n\n"); - this.group_list - .insert(msg.group_id.clone().into(), new_cursor.into()); - return self.poll_next(cx); - } - }, - None => { - tracing::warn!("skipping message streaming payload"); - this.state.set(State::Waiting); - // we are skipping this message and need to add the task - // back to the queue to start polling for the next one - return self.poll_next(cx); - // cx.waker().wake_by_ref(); - // return Poll::Pending; - } - } - } - Poll::Pending - } -} - -impl<'a, C, S> StreamGroupMessages<'a, C, S> -where - S: Stream<Item = std::result::Result<GroupMessage, xmtp_proto::Error>> + 'a, - C: ScopedGroupClient + 'a, -{ - /* - pub(super) fn drain( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Vec<Option<Result<StoredGroupMessage>>> { - let mut drained = Vec::new(); - while let Poll::Ready(msg) = self.as_mut().poll_next(cx) { - drained.push(msg); - } - drained - } - */ -} - -/// Future that processes a group message from the network -pub struct ProcessMessageFuture<Client> { - provider: XmtpOpenMlsProvider, - client: Client, - msg: group_message::V1, -} - -impl<C> ProcessMessageFuture<C> -where - C: ScopedGroupClient, -{ - /// Create a new Future to process a GroupMessage. - pub fn new(client: C, envelope: GroupMessage) -> Result<ProcessMessageFuture<C>> { - let msg = extract_message_v1(envelope)?; - let provider = client.mls_provider()?; - tracing::info!( - inbox_id = client.inbox_id(), - group_id = hex::encode(&msg.group_id), - "Received message streaming payload" - ); - - Ok(Self { - provider, - client, - msg, - }) - } - - fn inbox_id(&self) -> InboxIdRef<'_> { - self.client.inbox_id() - } - - /// process a message, returning the message from the database and the cursor of the message. - pub(crate) async fn process(self) -> Result<Option<(StoredGroupMessage, u64)>> { - let group_message::V1 { - // the cursor ID is the position in the monolithic backend topic - id: ref cursor_id, - ref created_ns, - .. - } = self.msg; - - tracing::info!( - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id, - "client [{}] is about to process streamed envelope: [{}]", - self.inbox_id(), - &cursor_id - ); - - if self.needs_to_sync(*cursor_id).await? { - self.process_stream_entry().await - } - - // Load the message from the DB to handle cases where it may have been already processed in - // another thread - let new_message = self - .provider - .conn_ref() - .get_group_message_by_timestamp(&self.msg.group_id, *created_ns as i64)?; - /* - .inspect(|e| { - if matches!(e, SubscribeError::GroupMessageNotFound) { - tracing::warn!( - cursor_id, - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - "group message not found" - ); - } - })?; - */ - if let Some(msg) = new_message { - Ok(Some((msg, *cursor_id))) - } else { - tracing::warn!( - cursor_id, - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - "group message not found" - ); - - Ok(None) - } - } - - /// stream processing function - async fn process_stream_entry(&self) { - let process_result = self - .client - .store() - .retryable_transaction_async(&self.provider, |provider| async move { - let (group, _) = - MlsGroup::new_validated(&self.client, self.msg.group_id.clone(), provider)?; - tracing::info!( - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id = self.msg.id, - "current epoch for [{}] in process_stream_entry()", - self.inbox_id(), - ); - group - .process_message(provider, &self.msg, false) - .await - // NOTE: We want to make sure we retry an error in process_message - .map_err(SubscribeError::ReceiveGroup) - }) - .await; - - if let Err(SubscribeError::ReceiveGroup(e)) = process_result { - tracing::warn!("error processing streamed message {e}"); - self.attempt_message_recovery().await - // This should never occur because we map the error to `ReceiveGroup` - // But still exists defensively - } else if let Err(e) = process_result { - tracing::error!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id = self.msg.id, - err = e.to_string(), - "process stream entry {:?}", - e - ); - } else { - tracing::trace!( - cursor_id = self.msg.id, - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - "message process in stream success" - ); - } - } - - /// Checks if a message has already been processed through a sync - // TODO: Make this not async, and instead of retry add it back to wake queue. - async fn needs_to_sync(&self, current_msg_cursor: u64) -> Result<bool> { - let check_for_last_cursor = || -> std::result::Result<i64, StorageError> { - self.provider - .conn_ref() - .get_last_cursor_for_id(&self.msg.group_id, EntityKind::Group) - }; - - let last_synced_id = retry_async!(Retry::default(), (async { check_for_last_cursor() }))?; - Ok(last_synced_id < current_msg_cursor as i64) - } - - /// Attempt a recovery sync if a group message failed to process - async fn attempt_message_recovery(&self) { - let group = MlsGroup::new( - &self.client, - self.msg.group_id.clone(), - self.msg.created_ns as i64, - ); - tracing::debug!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id = self.msg.id, - "attempting recovery sync" - ); - // Swallow errors here, since another process may have successfully saved the message - // to the DB - if let Err(err) = group.sync_with_conn(&self.provider).await { - tracing::warn!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id = self.msg.id, - err = %err, - "recovery sync triggered by streamed message failed: {}", err - ); - } else { - tracing::debug!( - inbox_id = self.client.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - cursor_id = self.msg.id, - "recovery sync triggered by streamed message successful" - ) - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use futures::stream::StreamExt; - use wasm_bindgen_test::wasm_bindgen_test; - - use crate::{assert_msg, assert_msg_exists}; - use crate::{builder::ClientBuilder, groups::GroupMetadataOptions}; - use xmtp_cryptography::utils::generate_local_wallet; - - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] - async fn test_stream_messages() { - xmtp_common::logger(); - let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); - let bob = ClientBuilder::new_test_client(&generate_local_wallet()).await; - - let alice_group = alice - .create_group(None, GroupMetadataOptions::default()) - .unwrap(); - tracing::info!("Group Id = [{}]", hex::encode(&alice_group.group_id)); - - alice_group - .add_members_by_inbox_id(&[bob.inbox_id()]) - .await - .unwrap(); - let bob_groups = bob - .sync_welcomes(&bob.mls_provider().unwrap()) - .await - .unwrap(); - let bob_group = bob_groups.first().unwrap(); - alice_group.sync().await.unwrap(); - - let stream = alice_group.stream().await.unwrap(); - futures::pin_mut!(stream); - bob_group.send_message(b"hello").await.unwrap(); - - // group updated msg/bob is added - // assert_msg_exists!(stream); - assert_msg!(stream, "hello"); - - bob_group.send_message(b"hello2").await.unwrap(); - assert_msg!(stream, "hello2"); - } -} diff --git a/xmtp_mls/src/groups/device_sync/consent_sync.rs b/xmtp_mls/src/groups/device_sync/consent_sync.rs index ea3d5587d..4a59a9552 100644 --- a/xmtp_mls/src/groups/device_sync/consent_sync.rs +++ b/xmtp_mls/src/groups/device_sync/consent_sync.rs @@ -45,7 +45,6 @@ pub(crate) mod tests { #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(target_family = "wasm", ignore)] async fn test_consent_sync() { - xmtp_common::logger(); let history_sync_url = format!("http://{}:{}", HISTORY_SERVER_HOST, HISTORY_SERVER_PORT); let wallet = generate_local_wallet(); let amal_a = ClientBuilder::new_test_client_with_history(&wallet, &history_sync_url).await; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index c6d60a05a..81adfc8fa 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -65,12 +65,10 @@ use crate::storage::{ NotFound, StorageError, }; use xmtp_common::time::now_ns; -use xmtp_proto::xmtp::mls::{ - message_contents::{ - content_types::ReactionV2, - plaintext_envelope::{Content, V1}, - EncodedContent, PlaintextEnvelope, - }, +use xmtp_proto::xmtp::mls::message_contents::{ + content_types::ReactionV2, + plaintext_envelope::{Content, V1}, + EncodedContent, PlaintextEnvelope, }; use crate::{ diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index ddbb119a7..6e0266abf 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -23,7 +23,6 @@ use std::collections::HashMap; use std::sync::{Arc, LazyLock, Mutex}; use storage::{xmtp_openmls_provider::XmtpOpenMlsProvider, DuplicateItem, StorageError}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -pub use xmtp_openmls_provider::XmtpOpenMlsProvider; pub use xmtp_id::InboxOwner; pub use xmtp_proto::api_client::trait_impls::*; diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index 86838c6d5..b127475bc 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -351,10 +351,8 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .limit(1) .filter(dsl::id.eq(id)); - - Ok(self - .raw_query(|conn| query.load(conn)) - .map(|mut g| g.pop())?) + let groups = self.raw_query(|conn| query.load(conn))?; + Ok(groups.into_iter().next()) } /// Return a single group that matches the given welcome ID @@ -366,14 +364,14 @@ impl DbConnection { .order(dsl::created_at_ns.asc()) .filter(dsl::welcome_id.eq(welcome_id)); - let mut groups = self.raw_query(|conn| query.load(conn))?; + let groups = self.raw_query(|conn| query.load(conn))?; if groups.len() > 1 { tracing::warn!( welcome_id, "More than one group found for welcome_id {welcome_id}" ); } - Ok(groups.pop()) + Ok(groups.into_iter().next()) } pub fn find_dm_group( @@ -386,12 +384,12 @@ impl DbConnection { .filter(dsl::dm_id.eq(Some(dm_id))) .order(dsl::last_message_ns.desc()); - let mut groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?; + let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?; if groups.len() > 1 { tracing::info!("More than one group found for dm_inbox_id {members:?}"); } - Ok(groups.pop()) + Ok(groups.into_iter().next()) } /// Updates group membership state diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 3f7af2ce8..68f5d11e2 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -57,7 +57,6 @@ use diesel::{ sql_query, }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; -use std::sync::Arc; use xmtp_common::{retry_async, Retry, RetryableError}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/"); @@ -366,17 +365,17 @@ where Fut: futures::Future<Output = Result<T, E>>, E: From<diesel::result::Error> + From<StorageError>, Db: 'a; - pub async fn retryable_transaction_async<'a, T, F, E, Fut>( - &self, - provider: &'a XmtpOpenMlsProviderPrivate<<Db as XmtpDb>::Connection>, - retry: Option<Retry>, - fun: F, - ) -> Result<T, E> - where - F: Copy + FnMut(&'a XmtpOpenMlsProviderPrivate<<Db as XmtpDb>::Connection>) -> Fut, - Fut: futures::Future<Output = Result<T, E>>, - E: From<diesel::result::Error> + From<StorageError> + RetryableError; - + #[allow(async_fn_in_trait)] + async fn retryable_transaction_async<'a, T, F, E, Fut>( + &'a self, + retry: Option<Retry>, + fun: F, + ) -> Result<T, E> + where + F: Copy + FnMut(&'a XmtpOpenMlsProviderPrivate<Db, <Db as XmtpDb>::Connection>) -> Fut, + Fut: futures::Future<Output = Result<T, E>>, + E: From<diesel::result::Error> + From<StorageError> + RetryableError, + Db: 'a; } impl<Db> ProviderTransactions<Db> for XmtpOpenMlsProviderPrivate<Db, <Db as XmtpDb>::Connection> @@ -463,15 +462,6 @@ where // ensuring we have only one strong reference let result = fun(self).await; let local_connection = self.conn_ref().inner_ref(); - if Arc::strong_count(&local_connection) > 1 { - tracing::warn!( - "More than 1 strong connection references still exist during async transaction" - ); - } - - if Arc::weak_count(&local_connection) > 1 { - tracing::warn!("More than 1 weak connection references still exist during transaction"); - } // after the closure finishes, `local_provider` should have the only reference ('strong') // to `XmtpOpenMlsProvider` inner `DbConnection`.. @@ -497,22 +487,21 @@ where } } - pub async fn retryable_transaction_async<'a, T, F, E, Fut>( - &self, - provider: &'a XmtpOpenMlsProviderPrivate<<Db as XmtpDb>::Connection>, - retry: Option<Retry>, - fun: F, - ) -> Result<T, E> - where - F: Copy + FnMut(&'a XmtpOpenMlsProviderPrivate<<Db as XmtpDb>::Connection>) -> Fut, - Fut: futures::Future<Output = Result<T, E>>, - E: From<diesel::result::Error> + From<StorageError> + RetryableError, - { - retry_async!( - retry.unwrap_or_default(), - (async { self.transaction_async(provider, fun).await }) - ) - } + async fn retryable_transaction_async<'a, T, F, E, Fut>( + &'a self, + retry: Option<Retry>, + fun: F, + ) -> Result<T, E> + where + F: Copy + FnMut(&'a XmtpOpenMlsProviderPrivate<Db, <Db as XmtpDb>::Connection>) -> Fut, + Fut: futures::Future<Output = Result<T, E>>, + E: From<diesel::result::Error> + From<StorageError> + RetryableError, + { + retry_async!( + retry.unwrap_or_default(), + (async { self.transaction_async(fun).await }) + ) + } } #[cfg(test)] @@ -779,6 +768,7 @@ pub(crate) mod tests { #[cfg_attr(not(target_arch = "wasm32"), tokio::test)] #[cfg(not(target_arch = "wasm32"))] async fn test_transaction_rollback() { + use std::sync::Arc; use std::sync::Barrier; let db_path = tmp_path(); diff --git a/xmtp_mls/src/subscriptions/mod.rs b/xmtp_mls/src/subscriptions/mod.rs index 616fd9052..efce36735 100644 --- a/xmtp_mls/src/subscriptions/mod.rs +++ b/xmtp_mls/src/subscriptions/mod.rs @@ -24,9 +24,9 @@ use crate::{ }, storage::{ consent_record::StoredConsentRecord, - group::{ConversationType, GroupQueryArgs, StoredGroup}, + group::ConversationType, group_message::StoredGroupMessage, - ProviderTransactions, StorageError, NotFound, group::ConversationType + StorageError, NotFound, }, Client, XmtpApi, }; diff --git a/xmtp_mls/src/subscriptions/stream_all.rs b/xmtp_mls/src/subscriptions/stream_all.rs index 1cab003ff..9de224e0a 100644 --- a/xmtp_mls/src/subscriptions/stream_all.rs +++ b/xmtp_mls/src/subscriptions/stream_all.rs @@ -1,16 +1,16 @@ use std::{ pin::Pin, - task::{Context, Poll, ready}, + task::{ready, Context, Poll}, }; use crate::subscriptions::stream_messages::MessagesApiSubscription; use crate::{ - types::GroupId, groups::{scoped_client::ScopedGroupClient, MlsGroup}, storage::{ group::{ConversationType, GroupQueryArgs}, group_message::StoredGroupMessage, }, + types::GroupId, Client, }; use futures::stream::Stream; @@ -102,9 +102,8 @@ where } if let Some(group) = ready!(this.conversations.poll_next(cx)) { this.messages.as_mut().add(group?); - return this.messages.poll_next(cx); } - Poll::Pending + this.messages.poll_next(cx) } } @@ -177,7 +176,6 @@ mod tests { #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] async fn test_stream_all_messages_unchanging_group_list() { - xmtp_common::logger(); let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -227,7 +225,7 @@ mod tests { .unwrap(); let alix_dm = alix - .create_dm_by_inbox_id(&alix.mls_provider().unwrap(), bo.inbox_id().to_string()) + .create_dm_by_inbox_id(bo.inbox_id().to_string()) .await .unwrap(); @@ -237,9 +235,15 @@ mod tests { .await .unwrap(); futures::pin_mut!(stream); - alix_dm.send_message("first DM msg".as_bytes()).await.unwrap(); + alix_dm + .send_message("first DM msg".as_bytes()) + .await + .unwrap(); tracing::info!("\n\nsent first DM message\n\n"); - alix_group.send_message("second GROUP msg".as_bytes()).await.unwrap(); + alix_group + .send_message("second GROUP msg".as_bytes()) + .await + .unwrap(); tracing::info!("\n\nsent second group msg\n\n"); assert_msg!(stream, "second GROUP msg"); tracing::info!("\n\ngot `second`: Group-Only message\n\n"); @@ -250,9 +254,15 @@ mod tests { .await .unwrap(); futures::pin_mut!(stream); - alix_group.send_message("second GROUP msg".as_bytes()).await.unwrap(); + alix_group + .send_message("second GROUP msg".as_bytes()) + .await + .unwrap(); tracing::info!("\n\nSENDING SECOND DM MSG\n\n"); - alix_dm.send_message("second DM msg".as_bytes()).await.unwrap(); + alix_dm + .send_message("second DM msg".as_bytes()) + .await + .unwrap(); tracing::info!("\nSENT SECOND DM MSG\n\n"); assert_msg!(stream, "second DM msg"); tracing::info!("Got second DM Only Message"); @@ -268,7 +278,7 @@ mod tests { assert_msg!(stream, "second"); } - #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] + #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_all_messages_does_not_lose_messages() { let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await; let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); @@ -288,9 +298,12 @@ mod tests { let alix_group_pointer = alix_group.clone(); crate::spawn(None, async move { let mut sent = 0; - for i in 0..50 { - let msg = format!("spam {i}"); - alix_group_pointer.send_message(msg.as_bytes()).await.unwrap(); + for i in 0..15 { + let msg = format!("main spam {i}"); + alix_group_pointer + .send_message(msg.as_bytes()) + .await + .unwrap(); sent += 1; xmtp_common::time::sleep(core::time::Duration::from_micros(100)).await; tracing::info!("sent {sent}"); @@ -303,27 +316,25 @@ mod tests { let caro_id = caro.inbox_id().to_string(); crate::spawn(None, async move { let caro = &caro_id; - for i in 0..50 { + for i in 0..5 { let new_group = eve .create_group(None, GroupMetadataOptions::default()) .unwrap(); new_group.add_members_by_inbox_id(&[caro]).await.unwrap(); tracing::info!("\n\n EVE SENDING {i} \n\n"); - let msg = format!("spam {i} from new group"); - new_group - .send_message(msg.as_bytes()) - .await - .unwrap(); + let msg = format!("EVE spam {i} from new group"); + new_group.send_message(msg.as_bytes()).await.unwrap(); } }); let mut messages = Vec::new(); - let _ = tokio::time::timeout(core::time::Duration::from_secs(30), async { + let timeout = if cfg!(target_arch = "wasm32") { 15 } else { 5 }; + let _ = xmtp_common::time::timeout(core::time::Duration::from_secs(timeout), async { futures::pin_mut!(stream); loop { - if messages.len() < 100 { + if messages.len() < 20 { if let Some(Ok(msg)) = stream.next().await { - tracing::info!( + tracing::error!( message_id = hex::encode(&msg.id), sender_inbox_id = msg.sender_inbox_id, sender_installation_id = hex::encode(&msg.sender_installation_id), @@ -342,7 +353,7 @@ mod tests { .await; tracing::info!("Total Messages: {}", messages.len()); - assert_eq!(messages.len(), 100); + assert_eq!(messages.len(), 20); } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))] @@ -369,7 +380,7 @@ mod tests { }); let mut messages = Vec::new(); - let _ = tokio::time::timeout(core::time::Duration::from_secs(20), async { + let _ = xmtp_common::time::timeout(core::time::Duration::from_secs(20), async { futures::pin_mut!(stream); loop { if messages.len() < 5 { diff --git a/xmtp_mls/src/subscriptions/stream_conversations.rs b/xmtp_mls/src/subscriptions/stream_conversations.rs index 8273e7081..0eeb39074 100644 --- a/xmtp_mls/src/subscriptions/stream_conversations.rs +++ b/xmtp_mls/src/subscriptions/stream_conversations.rs @@ -2,12 +2,12 @@ use std::{ collections::HashSet, future::Future, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use crate::{ groups::{scoped_client::ScopedGroupClient, MlsGroup}, - storage::{group::ConversationType, NotFound, refresh_state::EntityKind}, + storage::{group::ConversationType, refresh_state::EntityKind, NotFound, ProviderTransactions}, Client, XmtpOpenMlsProvider, }; use futures::{prelude::stream::Select, Stream}; @@ -64,19 +64,17 @@ impl Stream for BroadcastGroupStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { use std::task::Poll::*; let this = self.project(); - match this.inner.poll_next(cx) { - Ready(Some(event)) => { - let ev = xmtp_common::optify!(event, "Missed messages due to event queue lag") - .and_then(LocalEvents::group_filter); - if let Some(g) = ev { - Ready(Some(Ok(WelcomeOrGroup::Group(g)))) - } else { - // skip this item since it was either missed due to lag, or not a group - Pending - } + if let Some(event) = ready!(this.inner.poll_next(cx)) { + if let Some(group) = + xmtp_common::optify!(event, "Missed messages due to event queue lag") + .and_then(LocalEvents::group_filter) + { + Ready(Some(Ok(WelcomeOrGroup::Group(group)))) + } else { + Pending } - Pending => Pending, - Ready(None) => Ready(None), + } else { + Ready(None) } } } @@ -227,7 +225,6 @@ where Waiting => { match this.inner.poll_next(cx) { Ready(Some(item)) => { - tracing::info!("READY, STARTING TO PROCESS"); let mut this = self.as_mut().project(); let future = ProcessWelcomeFuture::new( this.known_welcome_ids.clone(), @@ -243,16 +240,15 @@ where // this will return immediately if we have already processed the welcome // and it exists in the db - let Processing { future } = this.state.project() else { unreachable!() }; + let Processing { future } = this.state.project() else { + unreachable!() + }; let poll = future.poll(cx); self.as_mut().try_process(poll, cx) } // stream ended - Ready(None) => { - tracing::info!("READY NONE"); - Ready(None) - } - Pending => Pending + Ready(None) => Ready(None), + Pending => Pending, } } Processing { future } => { @@ -291,8 +287,7 @@ where this.state.as_mut().set(ProcessState::Waiting); // we have to re-ad this task to the queue // to let http know we are waiting on the next item - cx.waker().wake_by_ref(); - Pending + self.poll_next(cx) } Ready(Err(e)) => Ready(Some(Err(e))), Pending => Pending, @@ -407,10 +402,8 @@ where "Trying to process streamed welcome" ); - let group = client - .context() - .store() - .retryable_transaction_async(provider, |provider| async { + let group = provider + .retryable_transaction_async(None, |provider| async { MlsGroup::create_from_encrypted_welcome( client, provider, @@ -471,7 +464,6 @@ mod test { #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_welcomes() { - xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bob = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let alice_bob_group = alice @@ -490,7 +482,6 @@ mod test { } #[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread"))] - #[cfg_attr(target_family = "wasm", ignore)] async fn test_dm_streaming() { let alix = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bo = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); diff --git a/xmtp_mls/src/subscriptions/stream_messages.rs b/xmtp_mls/src/subscriptions/stream_messages.rs index fa3726c82..24b6007fa 100644 --- a/xmtp_mls/src/subscriptions/stream_messages.rs +++ b/xmtp_mls/src/subscriptions/stream_messages.rs @@ -2,21 +2,21 @@ use std::{ collections::HashMap, future::Future, pin::Pin, - task::{Context, Poll, ready}, + task::{ready, Context, Poll}, }; use super::{Result, SubscribeError}; use crate::{ api::GroupFilter, groups::{scoped_client::ScopedGroupClient, MlsGroup}, - types::GroupId, storage::{ - group::StoredGroup, group_message::StoredGroupMessage, refresh_state::EntityKind, - StorageError, + encrypted_store::ProviderTransactions, group::StoredGroup, + group_message::StoredGroupMessage, refresh_state::EntityKind, StorageError, }, + types::GroupId, XmtpOpenMlsProvider, }; -use futures::{Stream, TryFutureExt}; +use futures::Stream; use pin_project_lite::pin_project; use xmtp_common::FutureWrapper; use xmtp_id::InboxIdRef; @@ -49,7 +49,6 @@ fn extract_message_v1(message: GroupMessage) -> Result<group_message::V1> { } } - /// the position of this message in the backend topic /// based only upon messages from the stream #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -66,10 +65,6 @@ impl MessagePosition { fn pos(&self) -> u64 { self.cursor.unwrap_or(0) } - - fn is_unknown(&self) -> bool { - self.cursor.is_none() - } } impl std::fmt::Display for MessagePosition { @@ -92,9 +87,7 @@ impl From<StoredGroup> for (Vec<u8>, MessagePosition) { impl From<u64> for MessagePosition { fn from(v: u64) -> MessagePosition { - Self { - cursor: Some(v), - } + Self { cursor: Some(v) } } } @@ -104,6 +97,7 @@ pin_project! { #[pin] state: State<'a, Subscription>, client: &'a C, group_list: HashMap<GroupId, MessagePosition>, + drained: Vec<Option<Result<GroupMessage>>>, } } @@ -120,7 +114,7 @@ pin_project! { #[pin] future: FutureWrapper<'a, Result<Option<(StoredGroupMessage, u64)>>> }, Adding { - #[pin] future: FutureWrapper<'a, Result<Out>> + #[pin] future: FutureWrapper<'a, Result<(Out, Vec<u8>, Option<u64>)>> } } } @@ -133,23 +127,50 @@ where C: ScopedGroupClient + 'a, <C as ScopedGroupClient>::ApiClient: XmtpApi + XmtpMlsStreams + 'a, { - pub async fn new( - client: &'a C, - group_list: Vec<GroupId>, - ) -> Result<Self> { + pub async fn new(client: &'a C, group_list: Vec<GroupId>) -> Result<Self> { tracing::debug!("setting up messages subscription"); - let group_list = group_list.into_iter().map(|group_id| { - Ok((group_id, 0u64)) - }).collect::<Result<HashMap<GroupId, u64>>>()?; + + let mut group_list = group_list + .into_iter() + .map(|group_id| (group_id, 0u64)) + .collect::<HashMap<GroupId, u64>>(); + + let cursors = group_list + .iter() + .map(|(group, _)| client.api().query_group_messages(group.to_vec(), Some(0))); + + let cursors = futures::future::join_all(cursors) + .await + .into_iter() + .map(|r| r.map_err(SubscribeError::from)) + .collect::<Result<Vec<_>>>()? + .into_iter() + .flatten() + .collect::<Vec<_>>(); + + for message in cursors { + let group_message::V1 { + id: cursor, + group_id, + .. + } = extract_message_v1(message)?; + group_list.entry(group_id.clone().into()).and_modify(|e| { + if *e < cursor { + *e = cursor + } + }); + tracing::info!( + "Subscribed to group {} at cursor {}", + hex::encode(&group_id), + group_list.get(group_id.as_slice()).unwrap() + ); + } + let filters: Vec<GroupFilter> = group_list .iter() .map(|(group_id, cursor)| GroupFilter::new(group_id.to_vec(), Some(*cursor))) .collect(); - for filter in &filters { - let messages = client.api().query_group_messages(filter.group_id.to_vec(), Some(1)).await; - tracing::info!("{:?}", messages); - tracing::debug!("Subscribing to {} for group messages", hex::encode(&filter.group_id)); - } + let subscription = client.api().subscribe_group_messages(filters).await?; Ok(Self { @@ -157,12 +178,12 @@ where client, state: Default::default(), group_list: group_list.into_iter().map(|(g, c)| (g, c.into())).collect(), + drained: Vec::new(), }) } /// Add a new group to this messages stream pub(super) fn add(mut self: Pin<&mut Self>, group: MlsGroup<C>) { - tracing::info!("creating new messages stream to add group {}", hex::encode(&group.group_id)); if self.group_list.contains_key(group.group_id.as_slice()) { tracing::info!("group {} already in stream", hex::encode(&group.group_id)); return; @@ -176,26 +197,50 @@ where hex::encode(&group.group_id) ); let this = self.as_mut().project(); - this.group_list.insert(group.group_id.into(), 1.into()); - // let mut filters = self.filters(); - // add the new group but not to our state. - // We will add the group to our state once we get the first message. - // In that message will be the real cursor, rather than a temporary `1` - // filters.push(GroupFilter::new(group.group_id, Some(1))); - let future = self.client.api().subscribe_group_messages(self.filters()).map_err(SubscribeError::from); + this.group_list + .insert(group.group_id.clone().into(), 0.into()); + let future = Self::subscribe(self.client, self.filters(), group.group_id); let mut this = self.as_mut().project(); - this.state.set(State::Adding { future: FutureWrapper::new(future)}); + this.state.set(State::Adding { + future: FutureWrapper::new(future), + }); + } + + // re-subscribe to the stream with a new group + async fn subscribe( + client: &'a C, + mut filters: Vec<GroupFilter>, + new_group: Vec<u8>, + ) -> Result<(MessagesApiSubscription<'a, C>, Vec<u8>, Option<u64>)> { + let msgs = client + .api() + .query_group_messages(new_group.to_vec(), Some(0)) + .await?; + + let mut cursor = None; + if let Some(m) = msgs.first() { + let m = extract_message_v1(m.clone())?; + if let Some(new) = filters.iter_mut().find(|f| &f.group_id == &new_group) { + new.id_cursor = Some(m.id); + cursor = Some(m.id); + } + } + let stream = client.api().subscribe_group_messages(filters).await?; + Ok((stream, new_group, cursor)) } - // Reinit with all the correct cursors - // this should result in the least amount of network calls & mitigate missed messages to get the right messages - // when groups are changing quickly - // TODO: can store a cursor on a message, or in refresh table, that is last processed message - // not necessarily synced. then we don't need this. - fn reinit(mut self: Pin<&mut Self>) { - let future = self.client.api().subscribe_group_messages(self.filters()).map_err(SubscribeError::from); + // needed mainly for slower connections when we may receive messages + // in between a switch. + pub(super) fn drain( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Vec<Option<Result<GroupMessage>>> { + let mut drained = Vec::new(); let mut this = self.as_mut().project(); - this.state.set(State::Adding { future: FutureWrapper::new(future)}); + while let Poll::Ready(msg) = this.inner.as_mut().poll_next(cx) { + drained.push(msg.map(|v| v.map_err(SubscribeError::from))); + } + drained } } @@ -207,32 +252,43 @@ where type Item = Result<StoredGroupMessage>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - // tracing::debug!("POLLING STREAM MESSAGES"); use std::task::Poll::*; use ProjectState::*; let mut this = self.as_mut().project(); match this.state.as_mut().project() { - Waiting => { + Waiting => { + if let Some(envelope) = this.drained.pop().flatten() { + let future = ProcessMessageFuture::new(*this.client, envelope?)?; + let future = future.process(); + this.state.set(State::Processing { + future: FutureWrapper::new(future), + }); + return self.try_update_state(cx); + } if let Some(envelope) = ready!(this.inner.poll_next(cx)) { - tracing::debug!("processing {:?} message in stream", envelope); let future = ProcessMessageFuture::new(*this.client, envelope?)?; let future = future.process(); this.state.set(State::Processing { future: FutureWrapper::new(future), }); self.try_update_state(cx) - } else { // the stream ended + } else { + // the stream ended Ready(None) } - }, + } Processing { .. } => self.try_update_state(cx), Adding { future } => { - let stream = ready!(future.poll(cx))?; + let (stream, group, cursor) = ready!(future.poll(cx))?; + let this = self.as_mut(); + cursor.and_then(|c| Some(this.set_cursor(group.as_slice(), c))); + let drained = self.as_mut().drain(cx); let mut this = self.as_mut().project(); + this.drained.extend(drained); this.inner.set(stream); + tracing::info!("finished establishing new messages stream"); this.state.as_mut().set(State::Waiting); - tracing::debug!("added group to messages stream"); self.poll_next(cx) } } @@ -254,10 +310,20 @@ where // Subscription: Stream<Item = std::result::Result<GroupMessage, xmtp_proto::Error>> + 'a, <C as ScopedGroupClient>::ApiClient: XmtpApi + XmtpMlsStreams + 'a, { + fn set_cursor(mut self: Pin<&mut Self>, group_id: &[u8], new_cursor: u64) { + let this = self.as_mut().project(); + if let Some(cursor) = this.group_list.get_mut(group_id) { + cursor.set(new_cursor); + } + } + /// Try to finish processing the stream item by polling the stored future. /// Update state to `Waiting` and insert the new cursor if ready. /// If Stream state is in `Waiting`, returns `Pending`. - fn try_update_state(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> { + fn try_update_state( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<<Self as Stream>::Item>> { use ProjectState::*; let mut this = self.as_mut().project(); @@ -265,25 +331,9 @@ where match ready!(future.poll(cx))? { Some((msg, new_cursor)) => { this.state.set(State::Waiting); - if let Some(tracked_cursor) = this.group_list.get_mut(msg.group_id.as_slice()) { - if tracked_cursor.is_unknown() { // we assume a cursor of 1 means unknown cursor - // reinit the stream with the correct cursor - tracked_cursor.set(new_cursor); - self.as_mut().reinit(); - // return Poll::Pending; - return self.poll_next(cx); - } else { - tracked_cursor.set(new_cursor); - return Poll::Ready(Some(Ok(msg))); - } - } else { // this should never happen - tracing::info!("\n\nGot new group\n\n"); - this.group_list - .insert(msg.group_id.clone().into(), new_cursor.into()); - return self.poll_next(cx); - } - // return Poll::Ready(Some(Ok(msg))); - }, + self.set_cursor(msg.group_id.as_slice(), new_cursor); + return Poll::Ready(Some(Ok(msg))); + } None => { tracing::warn!("skipping message streaming payload"); this.state.set(State::Waiting); @@ -295,24 +345,6 @@ where } } -impl<'a, C> StreamGroupMessages<'a, C, MessagesApiSubscription<'a, C>> -where - // S: Stream<Item = std::result::Result<GroupMessage, xmtp_proto::Error>> + 'a, - C: ScopedGroupClient + 'a, - <C as ScopedGroupClient>::ApiClient: XmtpApi + XmtpMlsStreams + 'a, -{ - pub(super) fn drain( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) { - let mut this = self.as_mut().project(); - // let mut drained = Vec::new(); - while let Poll::Ready(msg) = this.inner.as_mut().poll_next(cx) { - tracing::info!("msg: {:?}", msg); - } - } -} - /// Future that processes a group message from the network pub struct ProcessMessageFuture<Client> { provider: XmtpOpenMlsProvider, @@ -331,7 +363,8 @@ where tracing::info!( inbox_id = client.inbox_id(), group_id = hex::encode(&msg.group_id), - "Received message streaming payload" + cursor = msg.id, + "streamed new message" ); Ok(Self { @@ -373,18 +406,7 @@ where .provider .conn_ref() .get_group_message_by_timestamp(&self.msg.group_id, *created_ns as i64)?; - /* - .inspect(|e| { - if matches!(e, SubscribeError::GroupMessageNotFound) { - tracing::warn!( - cursor_id, - inbox_id = self.inbox_id(), - group_id = hex::encode(&self.msg.group_id), - "group message not found" - ); - } - })?; - */ + if let Some(msg) = new_message { Ok(Some((msg, *cursor_id))) } else { @@ -402,9 +424,8 @@ where /// stream processing function async fn process_stream_entry(&self) { let process_result = self - .client - .store() - .retryable_transaction_async(&self.provider, |provider| async move { + .provider + .retryable_transaction_async(None, |provider| async move { let (group, _) = MlsGroup::new_validated(&self.client, self.msg.group_id.clone(), provider)?; tracing::info!( @@ -500,13 +521,12 @@ mod tests { use futures::stream::StreamExt; use wasm_bindgen_test::wasm_bindgen_test; - use crate::{assert_msg, assert_msg_exists}; + use crate::assert_msg; use crate::{builder::ClientBuilder, groups::GroupMetadataOptions}; use xmtp_cryptography::utils::generate_local_wallet; #[wasm_bindgen_test(unsupported = tokio::test(flavor = "current_thread"))] async fn test_stream_messages() { - xmtp_common::logger(); let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bob = ClientBuilder::new_test_client(&generate_local_wallet()).await;