Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix http stream with stream all messages #1510

Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix http stream with stream all messages
insipx committed Jan 27, 2025
commit aa5bc1ae5a737c6b3179d0a5abc44949a6d79190
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 87 additions & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -26,6 +26,93 @@ use rand::{
RngCore,
};
use xmtp_cryptography::utils as crypto_utils;
use std::{pin::Pin, task::Poll, future::Future};
use futures::{Stream, FutureExt, StreamExt};

#[cfg(not(target_arch = "wasm32"))]
pub struct StreamWrapper<'a, I> {
inner: Pin<Box<dyn Stream<Item = I> + Send + 'a>>,
}

#[cfg(target_arch = "wasm32")]
pub struct StreamWrapper<'a, I> {
inner: Pin<Box<dyn Stream<Item = I> + 'a>>,
}

impl<'a, I> Stream for StreamWrapper<'a, I> {
type Item = I;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let inner = &mut self.inner;
futures::pin_mut!(inner);
inner.as_mut().poll_next(cx)
}
}

impl<'a, I> StreamWrapper<'a, I> {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<S>(stream: S) -> Self
where
S: Stream<Item = I> + Send + 'a,
{
Self {
inner: stream.boxed(),
}
}

#[cfg(target_arch = "wasm32")]
pub fn new<S>(stream: S) -> Self
where
S: Stream<Item = I> + 'a,
{
Self {
inner: stream.boxed_local(),
}
}
}

// Wrappers to deal with Send Bounds
#[cfg(not(target_arch = "wasm32"))]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + Send + 'a>>,
}

#[cfg(target_arch = "wasm32")]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + 'a>>,
}

impl<'a, O> Future for FutureWrapper<'a, O> {
type Output = O;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let inner = &mut self.inner;
futures::pin_mut!(inner);
inner.as_mut().poll(cx)
}
}

impl<'a, O> FutureWrapper<'a, O> {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + Send + 'a,
{
Self {
inner: future.boxed(),
}
}

#[cfg(target_arch = "wasm32")]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + 'a,
{
Self {
inner: future.boxed_local(),
}
}
}

pub fn rand_string<const N: usize>() -> String {
Alphanumeric.sample_string(&mut crypto_utils::rng(), N)
4 changes: 2 additions & 2 deletions common/src/test.rs
Original file line number Diff line number Diff line change
@@ -74,8 +74,8 @@ pub fn logger() {

INIT.get_or_init(|| {
let filter = EnvFilter::builder()
.with_default_directive(tracing::metadata::LevelFilter::DEBUG.into())
.from_env_lossy();
.parse_lossy("xmtp_mls::subscriptions=TRACE,xmtp_api_http=TRACE");
// .with_default_directive(tracing::metadata::LevelFilter::DEBUG.into())

tracing_subscriber::registry()
.with(tracing_wasm::WASMLayer::default())
8 changes: 1 addition & 7 deletions dev/test-wasm
Original file line number Diff line number Diff line change
@@ -6,10 +6,4 @@ WASM_BINDGEN_SPLIT_LINKED_MODULES=1 \
WASM_BINDGEN_TEST_TIMEOUT=120 \
CHROMEDRIVER="chromedriver" \
cargo test --target wasm32-unknown-unknown --release \
-p xmtp_mls -p xmtp_id -p xmtp_api_http -p xmtp_cryptography -- \
--skip xmtp_mls::subscriptions \
--skip xmtp_mls::groups::subscriptions \
--skip xmtp_mls::storage::encrypted_store::group_message::tests::it_cannot_insert_message_without_group \
--skip xmtp_mls::groups::tests::process_messages_abort_on_retryable_error \
--skip xmtp_mls::storage::encrypted_store::group::tests::test_find_groups \
--skip xmtp_mls::storage::encrypted_store::group::tests::test_installations_last_checked_is_updated
-p xmtp_mls -- xmtp_mls::subscriptions
8 changes: 1 addition & 7 deletions dev/test-wasm-interactive
Original file line number Diff line number Diff line change
@@ -15,10 +15,4 @@ WASM_BINDGEN_SPLIT_LINKED_MODULES=1 \
WASM_BINDGEN_TEST_ONLY_WEB=1 \
NO_HEADLESS=1 \
cargo test --target wasm32-unknown-unknown --release \
-p $PACKAGE -- \
--skip xmtp_mls::subscriptions \
--skip xmtp_mls::groups::subscriptions \
--skip xmtp_mls::storage::encrypted_store::group_message::tests::it_cannot_insert_message_without_group \
--skip xmtp_mls::groups::tests::process_messages_abort_on_retryable_error \
--skip xmtp_mls::storage::encrypted_store::group::tests::test_find_groups \
--skip xmtp_mls::storage::encrypted_store::group::tests::test_installations_last_checked_is_updated
-p $PACKAGE -- subscriptions::stream_conversations::test::test_stream_welcomes
1 change: 1 addition & 0 deletions xmtp_api_http/Cargo.toml
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
async-trait = "0.1"
bytes = "1.9"
pin-project-lite = "0.2.15"
xmtp_common.workspace = true

[dev-dependencies]
xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] }
313 changes: 196 additions & 117 deletions xmtp_api_http/src/http_stream.rs

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions xmtp_api_http/src/lib.rs
Original file line number Diff line number Diff line change
@@ -258,6 +258,7 @@ impl XmtpMlsStreams for XmtpHttpApiClient {
#[cfg(target_arch = "wasm32")]
type WelcomeMessageStream<'a> = stream::LocalBoxStream<'a, Result<WelcomeMessage, Error>>;

#[tracing::instrument(skip_all)]
async fn subscribe_group_messages(
&self,
request: SubscribeGroupMessagesRequest,
@@ -267,9 +268,11 @@ impl XmtpMlsStreams for XmtpHttpApiClient {
request,
self.endpoint(ApiEndpoints::SUBSCRIBE_GROUP_MESSAGES),
self.http_client.clone(),
))
)
.await?)
}

#[tracing::instrument(skip_all)]
async fn subscribe_welcome_messages(
&self,
request: SubscribeWelcomeMessagesRequest,
@@ -279,7 +282,8 @@ impl XmtpMlsStreams for XmtpHttpApiClient {
request,
self.endpoint(ApiEndpoints::SUBSCRIBE_WELCOME_MESSAGES),
self.http_client.clone(),
))
)
.await?)
}
}

5 changes: 4 additions & 1 deletion xmtp_mls/src/api/mls.rs
Original file line number Diff line number Diff line change
@@ -298,11 +298,14 @@ where
ApiClient: XmtpMlsStreams,
{
tracing::debug!(inbox_id = self.inbox_id, "subscribing to welcome messages");
// _NOTE_:
// Default ID Cursor should be one
// else we miss welcome messages
self.api_client
.subscribe_welcome_messages(SubscribeWelcomeMessagesRequest {
filters: vec![WelcomeFilterProto {
installation_key: installation_key.to_vec(),
id_cursor: id_cursor.unwrap_or(0),
id_cursor: id_cursor.unwrap_or(1),
}],
})
.await
21 changes: 20 additions & 1 deletion xmtp_mls/src/groups/validated_commit.rs
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ impl RetryableError for CommitValidationError {
}
}

#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Clone, PartialEq, Hash)]
pub struct CommitParticipant {
pub inbox_id: String,
pub installation_id: Vec<u8>,
@@ -103,6 +103,25 @@ pub struct CommitParticipant {
pub is_super_admin: bool,
}

impl std::fmt::Debug for CommitParticipant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
ref inbox_id,
ref installation_id,
ref is_creator,
ref is_admin,
ref is_super_admin,
} = self;
write!(f, "CommitParticipant {{ inbox_id={}, installation_id={}, is_creator={}, is_admin={}, is_super_admin={} }}",
inbox_id,
hex::encode(&installation_id),
is_creator,
is_admin,
is_super_admin,
)
}
}

impl CommitParticipant {
pub fn build(
inbox_id: String,
47 changes: 2 additions & 45 deletions xmtp_mls/src/subscriptions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures::{FutureExt, Stream, StreamExt};
use futures::{Stream, StreamExt};
use prost::Message;
use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc, task::Poll};
use std::{collections::HashSet, sync::Arc};
use tokio::{
sync::{broadcast, oneshot},
task::JoinHandle,
@@ -47,49 +47,6 @@ impl RetryableError for LocalEventError {
}
}

// Wrappers to deal with Send Bounds
#[cfg(not(target_arch = "wasm32"))]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + Send + 'a>>,
}

#[cfg(target_arch = "wasm32")]
pub struct FutureWrapper<'a, O> {
inner: Pin<Box<dyn Future<Output = O> + 'a>>,
}

impl<'a, O> Future for FutureWrapper<'a, O> {
type Output = O;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let inner = &mut self.inner;
futures::pin_mut!(inner);
inner.as_mut().poll(cx)
}
}

impl<'a, O> FutureWrapper<'a, O> {
#[cfg(not(target_arch = "wasm32"))]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + Send + 'a,
{
Self {
inner: future.boxed(),
}
}

#[cfg(target_arch = "wasm32")]
pub fn new<F>(future: F) -> Self
where
F: Future<Output = O> + 'a,
{
Self {
inner: future.boxed_local(),
}
}
}

#[derive(Debug)]
/// Wrapper around a [`tokio::task::JoinHandle`] but with a oneshot receiver
/// which allows waiting for a `with_callback` stream fn to be ready for stream items.
67 changes: 64 additions & 3 deletions xmtp_mls/src/subscriptions/stream_all.rs
Original file line number Diff line number Diff line change
@@ -20,9 +20,10 @@ use xmtp_proto::api_client::{trait_impls::XmtpApi, XmtpMlsStreams};
use super::{
stream_conversations::{StreamConversations, WelcomesApiSubscription},
stream_messages::StreamGroupMessages,
FutureWrapper, Result, SubscribeError,
Result, SubscribeError,
};
use pin_project_lite::pin_project;
use xmtp_common::FutureWrapper;

pin_project! {
pub(super) struct StreamAllMessages<'a, C, Conversations, Messages> {
@@ -123,8 +124,8 @@ where
use std::task::Poll::*;
use SwitchProject::*;
let this = self.as_mut().project();

let state = this.state.project();

match state {
Waiting => {
if let Ready(msg) = this.messages.poll_next(cx) {
@@ -253,13 +254,16 @@ mod tests {

alix_group.send_message(b"first").await.unwrap();
assert_msg!(stream, "first");
tracing::info!("\n\nGOT FIRST\n\n");
let bo_group = bo.create_dm(caro_wallet.get_address()).await.unwrap();

bo_group.send_message(b"second").await.unwrap();
assert_msg!(stream, "second");
tracing::info!("\n\nGOT SECOND\n\n");

alix_group.send_message(b"third").await.unwrap();
assert_msg!(stream, "third");
tracing::info!("\n\nGOT THIRD\n\n");

let alix_group_2 = alix
.create_group(None, GroupMetadataOptions::default())
@@ -271,13 +275,16 @@ mod tests {

alix_group.send_message(b"fourth").await.unwrap();
assert_msg!(stream, "fourth");
tracing::info!("\n\nGOT FOURTH\n\n");

alix_group_2.send_message(b"fifth").await.unwrap();
assert_msg!(stream, "fifth");
tracing::info!("\n\nGOT FIFTH\n\n");
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
async fn test_stream_all_messages_unchanging_group_list() {
xmtp_common::logger();
let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await;
@@ -338,8 +345,10 @@ mod tests {
.unwrap();
futures::pin_mut!(stream);
alix_dm.send_message("first".as_bytes()).await.unwrap();
tracing::info!("sent first message");
alix_group.send_message("second".as_bytes()).await.unwrap();
assert_msg!(stream, "second");
tracing::info!("got second Group-Only message");

// Start a stream with only dms
// Wait for 2 seconds for the group creation to be streamed
@@ -351,6 +360,7 @@ mod tests {
alix_group.send_message("first".as_bytes()).await.unwrap();
alix_dm.send_message("second".as_bytes()).await.unwrap();
assert_msg!(stream, "second");
tracing::info!("Got second DM ONLy Message");

// Start a stream with all conversations
// Wait for 2 seconds for the group creation to be streamed
@@ -411,7 +421,7 @@ mod tests {
});

let mut messages = Vec::new();
let _ = tokio::time::timeout(core::time::Duration::from_secs(60), async {
let _ = tokio::time::timeout(core::time::Duration::from_secs(30), async {
futures::pin_mut!(stream);
loop {
if messages.len() < 100 {
@@ -437,4 +447,55 @@ mod tests {
tracing::info!("Total Messages: {}", messages.len());
assert_eq!(messages.len(), 100);
}

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
async fn test_stream_all_messages_detached_group_changes() {
let caro = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let hale = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
tracing::info!(inbox_id = hale.inbox_id(), "HALE");
let stream = caro.stream_all_messages(None).await.unwrap();

let caro_id = caro.inbox_id().to_string();
crate::spawn(None, async move {
let caro = &caro_id;
for i in 0..5 {
let new_group = hale
.create_group(None, GroupMetadataOptions::default())
.unwrap();
new_group.add_members_by_inbox_id(&[caro]).await.unwrap();
tracing::info!("\n\n HALE SENDING {i} \n\n");
new_group
.send_message(b"spam from new group")
.await
.unwrap();
}
});

let mut messages = Vec::new();
let _ = tokio::time::timeout(core::time::Duration::from_secs(5), async {
futures::pin_mut!(stream);
loop {
if messages.len() < 5 {
if let Some(Ok(msg)) = stream.next().await {
tracing::info!(
message_id = hex::encode(&msg.id),
sender_inbox_id = msg.sender_inbox_id,
sender_installation_id = hex::encode(&msg.sender_installation_id),
group_id = hex::encode(&msg.group_id),
"GOT MESSAGE {}, text={}",
messages.len(),
String::from_utf8_lossy(msg.decrypted_message_bytes.as_slice())
);
messages.push(msg)
}
} else {
break;
}
}
})
.await;

tracing::info!("Total Messages: {}", messages.len());
assert_eq!(messages.len(), 5);
}
}
123 changes: 86 additions & 37 deletions xmtp_mls/src/subscriptions/stream_conversations.rs
Original file line number Diff line number Diff line change
@@ -19,7 +19,8 @@ use xmtp_proto::{
xmtp::mls::api::v1::{welcome_message, WelcomeMessage},
};

use super::{FutureWrapper, LocalEvents, Result, SubscribeError};
use super::{LocalEvents, Result, SubscribeError};
use xmtp_common::FutureWrapper;

#[derive(thiserror::Error, Debug)]
pub enum ConversationStreamError {
@@ -137,6 +138,15 @@ pin_project! {
}
}

impl<'a, C> ProcessState<'a, C> {
fn try_poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Option<Poll<Result<Option<(MlsGroup<C>, Option<i64>)>>>> {
match self.as_mut().project() {
ProcessProject::Waiting => None,
ProcessProject::Processing { future } => Some(future.poll(cx))
}
}
}

// we can't avoid the cfg(target_arch) without making the entire
// 'process_new_item' flow a Future, which makes this code
// significantly more difficult to modify. The other option is storing a
@@ -173,7 +183,8 @@ where
let provider = client.mls_provider()?;
let conn = provider.conn_ref();
let installation_key = client.installation_public_key();
let id_cursor = 0;
// _NOTE_ IdCursor should be one
let id_cursor = 1;
tracing::info!(
inbox_id = client.inbox_id(),
"Setting up conversation stream"
@@ -213,11 +224,16 @@ where
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
use std::task::Poll::*;
let mut this = self.as_mut().project();
match this.state.as_mut().project() {
ProcessProject::Waiting => {
use ProcessProject::*;

let this = self.as_mut().project();
let state = this.state.project();

match state {
Waiting => {
match this.inner.poll_next(cx) {
Ready(Some(item)) => {
let mut this = self.as_mut().project();
let future = ProcessWelcomeFuture::new(
this.known_welcome_ids.clone(),
this.client.clone(),
@@ -228,37 +244,62 @@ where
this.state.set(ProcessState::Processing {
future: FutureWrapper::new(future.process()),
});
cx.waker().wake_by_ref();
Pending
// try to process the future immediately
// this will return immediately if we have already processed the welcome
// and it exists in the db
let poll = this.state.try_poll(cx).expect("future just set & pinned in state");
self.as_mut().try_process(poll, cx)
}
// stream ended
Ready(None) => Ready(None),
Ready(None) => {
Ready(None)
},
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
},
Processing { future } => {
let poll = future.poll(cx);
self.as_mut().try_process(poll, cx)
}
ProcessProject::Processing { future } => match future.poll(cx) {
Ready(Ok(Some((group, welcome_id)))) => {
if let Some(id) = welcome_id {
this.known_welcome_ids.insert(id);
}
this.state.set(ProcessState::Waiting);
Ready(Some(Ok(group)))
}
// we are ignoring this payload
Ready(Ok(None)) => {
this.state.set(ProcessState::Waiting);
cx.waker().wake_by_ref();
Pending
}
Ready(Err(e)) => Ready(Some(Err(e))),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}

impl<'a, C, Subscription> StreamConversations<'a, C, Subscription>
where
C: ScopedGroupClient + Clone + 'a,
Subscription: Stream<Item = Result<WelcomeOrGroup>> + 'a,
{
/// Try to process the welcome future
fn try_process(
mut self: Pin<&mut Self>,
poll: Poll<Result<Option<(MlsGroup<C>, Option<i64>)>>>,
cx: &mut Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
use Poll::*;
let mut this = self.as_mut().project();
match poll {
Ready(Ok(Some((group, welcome_id)))) => {
if let Some(id) = welcome_id {
this.known_welcome_ids.insert(id);
}
},
this.state.set(ProcessState::Waiting);
Ready(Some(Ok(group)))
}
// we are ignoring this payload
Ready(Ok(None)) => {
this.state.as_mut().set(ProcessState::Waiting);
cx.waker().wake_by_ref();
Pending
}
Ready(Err(e)) => Ready(Some(Err(e))),
Pending => {
cx.waker().wake_by_ref();
Pending
}
}
}
}
@@ -316,10 +357,22 @@ where
use WelcomeOrGroup::*;
let (group, welcome_id) = match self.item {
Welcome(ref w) => {
let (group, id) = self.on_welcome(w).await?;
let welcome = extract_welcome_message(w)?;
let id = welcome.id as i64;
// try to load it from store first and avoid overhead
// of processing a welcome & erroring
// for immediate return, this must stay in the top-level future,
// to avoid a possible yield on the await in on_welcome.
if self.known_welcome_ids.contains(&id) {
tracing::debug!("Found existing welcome. Returning from db & skipping processing");
return Ok(Some(self.load_from_store(id).map(|(g, v)| (g, Some(v)))?));
}

let (group, id) = self.on_welcome(welcome).await?;
(group, Some(id))
}
Group(id) => {
tracing::debug!("Stream conversations got existing group, pulling from db.");
let (group, stored_group) =
MlsGroup::new_validated(self.client, id, &self.provider)?;
(group, stored_group.welcome_id)
@@ -334,22 +387,16 @@ where
}

/// process a new welcome, returning the Group & Welcome ID
async fn on_welcome(&self, welcome: &WelcomeMessage) -> Result<(MlsGroup<C>, i64)> {
async fn on_welcome(&self, welcome: &welcome_message::V1) -> Result<(MlsGroup<C>, i64)> {
let welcome_message::V1 {
id,
created_ns: _,
ref installation_key,
ref data,
ref hpke_public_key,
} = extract_welcome_message(welcome)?;
} = welcome;
let id = *id as i64;

// try to load it from store first and avoid overhead
// of processing a welcome & erroring
if self.known_welcome_ids.contains(&id) {
return self.load_from_store(id);
}

let Self {
ref client,
ref provider,
@@ -377,7 +424,8 @@ where
.await;

if let Err(e) = group {
// try to load it from the store again
tracing::info!("Processing welcome failed, trying to load existing..");
// try to load it from the store again in case of race
return self
.load_from_store(id)
.map_err(|_| SubscribeError::from(e));
@@ -424,6 +472,7 @@ mod test {

#[wasm_bindgen_test(unsupported = tokio::test(flavor = "multi_thread", worker_threads = 10))]
async fn test_stream_welcomes() {
xmtp_common::logger();
let alice = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let bob = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let alice_bob_group = alice
3 changes: 2 additions & 1 deletion xmtp_mls/src/subscriptions/stream_messages.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use std::{
task::{Context, Poll},
};

use super::{FutureWrapper, Result, SubscribeError};
use super::{Result, SubscribeError};
use crate::{
api::GroupFilter,
groups::{scoped_client::ScopedGroupClient, MlsGroup},
@@ -17,6 +17,7 @@ use crate::{
};
use futures::Stream;
use pin_project_lite::pin_project;
use xmtp_common::FutureWrapper;
use xmtp_common::{retry_async, Retry};
use xmtp_id::InboxIdRef;
use xmtp_proto::{