
sdk: Allow limiting the number of concurrent network requests #3625

1 change: 1 addition & 0 deletions crates/matrix-sdk/CHANGELOG.md
@@ -20,6 +20,7 @@ Breaking changes:

Additions:

- Add `RequestConfig.max_concurrent_requests`, which allows limiting the maximum number of
  concurrent requests the internal HTTP client issues. All further requests wait until the
  number of running requests drops below that threshold again.
- Expose new method `Client::Oidc::login_with_qr_code()`.
([#3466](https://github.com/matrix-org/matrix-rust-sdk/pull/3466))
- Add the `ClientBuilder::add_root_certificates()` method which re-exposes the
18 changes: 16 additions & 2 deletions crates/matrix-sdk/src/config/request.rs
@@ -14,6 +14,7 @@

use std::{
fmt::{self, Debug},
num::NonZeroUsize,
time::Duration,
};

@@ -44,18 +45,21 @@ pub struct RequestConfig {
pub(crate) timeout: Duration,
pub(crate) retry_limit: Option<u64>,
pub(crate) retry_timeout: Option<Duration>,
pub(crate) max_concurrent_requests: Option<NonZeroUsize>,
pub(crate) force_auth: bool,
}

#[cfg(not(tarpaulin_include))]
impl Debug for RequestConfig {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
self;

let mut res = fmt.debug_struct("RequestConfig");
res.field("timeout", timeout)
.maybe_field("retry_limit", retry_limit)
.maybe_field("retry_timeout", retry_timeout);
.maybe_field("retry_timeout", retry_timeout)
.maybe_field("max_concurrent_requests", max_concurrent_requests);

if *force_auth {
res.field("force_auth", &true);
@@ -71,6 +75,7 @@ impl Default for RequestConfig {
timeout: DEFAULT_REQUEST_TIMEOUT,
retry_limit: Default::default(),
retry_timeout: Default::default(),
max_concurrent_requests: Default::default(),
force_auth: false,
}
}
@@ -106,6 +111,15 @@ impl RequestConfig {
self
}

/// The total limit of requests that are pending or running concurrently.
/// Any additional request beyond that number waits until one of the running
/// requests has finished. Requests are queued fairly.
#[must_use]
pub fn max_concurrent_requests(mut self, limit: Option<NonZeroUsize>) -> Self {
self.max_concurrent_requests = limit;
self
}

/// Set the timeout duration for all HTTP requests.
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
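
As a usage illustration, here is a minimal sketch of how a caller could opt into the new limit. It only relies on the builder API added above together with the usual ClientBuilder; the homeserver URL and the whoami() call are placeholders:

use std::num::NonZeroUsize;

use matrix_sdk::{config::RequestConfig, Client};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Allow at most 10 requests to be in flight at any time; any further
    // request waits until one of the running requests has finished.
    let config = RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(10));

    let client = Client::builder()
        .homeserver_url("https://example.org")
        .request_config(config)
        .build()
        .await?;

    // Every request issued through this client now goes through the limiter.
    client.whoami().await?;

    Ok(())
}

Passing None instead of NonZeroUsize::new(10) leaves the client unlimited, which is also the default.
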
142 changes: 141 additions & 1 deletion crates/matrix-sdk/src/http_client/mod.rs
@@ -15,6 +15,7 @@
use std::{
any::type_name,
fmt::Debug,
num::NonZeroUsize,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
@@ -30,6 +31,7 @@ use ruma::api::{
error::{FromHttpResponseError, IntoHttpError},
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
};
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{debug, field::debug, instrument, trace};

use crate::{config::RequestConfig, error::HttpError};
@@ -48,16 +50,47 @@ pub(crate) use native::HttpSettings;

pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug)]
struct MaybeSemaphore(Arc<Option<Semaphore>>);

#[allow(dead_code)] // holding this until drop is all we are doing
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);

impl MaybeSemaphore {
fn new(max: Option<NonZeroUsize>) -> Self {
let inner = max.map(|i| Semaphore::new(i.into()));
MaybeSemaphore(Arc::new(inner))
}

async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
match self.0.as_ref() {
Some(inner) => {
// ignoring errors as we never close this
Member:
What kind of errors do you mean here?

Contributor Author:
If the semaphore has been closed, acquire() returns an AcquireError. That only ever happens when the Semaphore is explicitly closed, which this MaybeSemaphore never does, so the error can never occur and can safely be ignored.

Member:
Can you please update the comment to make this clear?

MaybeSemaphorePermit(inner.acquire().await.ok())
}
None => MaybeSemaphorePermit(None),
}
}
}

#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
pub(crate) inner: reqwest::Client,
pub(crate) request_config: RequestConfig,
concurrent_request_semaphore: MaybeSemaphore,
next_request_id: Arc<AtomicU64>,
}

impl HttpClient {
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
HttpClient {
inner,
request_config,
concurrent_request_semaphore: MaybeSemaphore::new(
request_config.max_concurrent_requests,
),
next_request_id: AtomicU64::new(0).into(),
}
}

fn get_request_id(&self) -> String {
@@ -184,6 +217,9 @@ impl HttpClient {
request
};

// The permit is released when this handle is dropped at the end of this function.
let _handle = self.concurrent_request_semaphore.acquire().await;

debug!("Sending request");

// There's a bunch of state in send_request, factor out a pinned inner
@@ -259,3 +295,107 @@ impl tower::Service<http_old::Request<Bytes>> for HttpClient {
Box::pin(fut)
}
}
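
As a standalone illustration of the gating pattern used by MaybeSemaphore above (plain Tokio, independent of the SDK types), the following sketch shows permits limiting concurrency and why acquire() can only fail once the semaphore has been closed, which, as discussed in the review thread, never happens here:

use std::{sync::Arc, time::Duration};

use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // At most two tasks may run the guarded section at the same time.
    let gate = Arc::new(Semaphore::new(2));

    let mut handles = Vec::new();
    for n in 0..5 {
        let gate = Arc::clone(&gate);
        handles.push(tokio::spawn(async move {
            // acquire() only returns Err(AcquireError) if the semaphore has been
            // closed; this one never is, so the error cannot occur here.
            let _permit = gate.acquire().await.expect("semaphore is never closed");
            println!("task {n} holds a permit");
            tokio::time::sleep(Duration::from_millis(100)).await;
            // The permit is released when _permit is dropped at the end of the task.
        }));
    }

    for handle in handles {
        handle.await.expect("task panicked");
    }
}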

#[cfg(test)]
mod tests {
use crate::{
http_client::RequestConfig,
test_utils::{set_client_session, test_client_builder_with_server},
};
use matrix_sdk_test::async_test;
use std::{
num::NonZeroUsize,
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
time::Duration,
};
use wiremock::{matchers::method, Mock, Request, ResponseTemplate};

#[async_test]
async fn ensure_concurrent_request_limit_is_observed() {
let (client_builder, server) = test_client_builder_with_server().await;
let mut client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
.build()
.await
.unwrap();

set_client_session(&mut client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
// we stall the requests
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let _bg_task = tokio::spawn(async move {
let mut pollers = Vec::new();

for _n in 0..10 {
pollers.push(client.whoami());
}
// issue parallel execution
futures_util::future::join_all(pollers).await
});

// give it a moment to issue the requests
tokio::time::sleep(Duration::from_secs(2)).await;

assert_eq!(
counter.load(Ordering::SeqCst),
5,
"More requests passed than the limit we configured"
);
}

#[async_test]
async fn ensure_no_max_concurrent_request_does_not_limit() {
let (client_builder, server) = test_client_builder_with_server().await;
let mut client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(None))
.build()
.await
.unwrap();

set_client_session(&mut client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
// we stall the requests
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let _bg_task = tokio::spawn(async move {
let mut pollers = Vec::new();

for _n in 0..254 {
pollers.push(client.whoami());
}
// issue parallel execution
futures_util::future::join_all(pollers).await
});

// give it a moment to issue the requests
tokio::time::sleep(Duration::from_secs(2)).await;

assert_eq!(
counter.load(Ordering::SeqCst),
254,
"Not all requests were issued even though no limit was configured"
);
}
}