Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdk: Allow to limit the number of concurrent network requests #3625

1 change: 1 addition & 0 deletions crates/matrix-sdk/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Breaking changes:

Additions:

- new `RequestConfig.max_concurrent_requests` which allows to limit the maximum number of concurrent requests the internal HTTP client issues (all others have to wait until the number drops below that threshold again)
- Expose new method `Client::Oidc::login_with_qr_code()`.
([#3466](https://github.com/matrix-org/matrix-rust-sdk/pull/3466))
- Add the `ClientBuilder::add_root_certificates()` method which re-exposes the
Expand Down
18 changes: 16 additions & 2 deletions crates/matrix-sdk/src/config/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::{
fmt::{self, Debug},
num::NonZeroUsize,
time::Duration,
};

Expand Down Expand Up @@ -44,18 +45,21 @@ pub struct RequestConfig {
pub(crate) timeout: Duration,
pub(crate) retry_limit: Option<u64>,
pub(crate) retry_timeout: Option<Duration>,
pub(crate) max_concurrent_requests: Option<NonZeroUsize>,
pub(crate) force_auth: bool,
}

#[cfg(not(tarpaulin_include))]
impl Debug for RequestConfig {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
self;

let mut res = fmt.debug_struct("RequestConfig");
res.field("timeout", timeout)
.maybe_field("retry_limit", retry_limit)
.maybe_field("retry_timeout", retry_timeout);
.maybe_field("retry_timeout", retry_timeout)
.maybe_field("max_concurrent_requests", max_concurrent_requests);

if *force_auth {
res.field("force_auth", &true);
Expand All @@ -71,6 +75,7 @@ impl Default for RequestConfig {
timeout: DEFAULT_REQUEST_TIMEOUT,
retry_limit: Default::default(),
retry_timeout: Default::default(),
max_concurrent_requests: Default::default(),
force_auth: false,
}
}
Expand Down Expand Up @@ -106,6 +111,15 @@ impl RequestConfig {
self
}

/// The total limit of request that are pending or run concurrently.
/// Any additional request beyond that number will be waiting until another
/// concurrent requests finished. Requests are queued fairly.
#[must_use]
pub fn max_concurrent_requests(mut self, limit: Option<NonZeroUsize>) -> Self {
self.max_concurrent_requests = limit;
self
}

/// Set the timeout duration for all HTTP requests.
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
Expand Down
147 changes: 146 additions & 1 deletion crates/matrix-sdk/src/http_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::{
any::type_name,
fmt::Debug,
num::NonZeroUsize,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
Expand All @@ -30,6 +31,7 @@ use ruma::api::{
error::{FromHttpResponseError, IntoHttpError},
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
};
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{debug, field::debug, instrument, trace};

use crate::{config::RequestConfig, error::HttpError};
Expand All @@ -48,16 +50,48 @@ pub(crate) use native::HttpSettings;

pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug)]
struct MaybeSemaphore(Arc<Option<Semaphore>>);

#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);

impl MaybeSemaphore {
fn new(max: Option<NonZeroUsize>) -> Self {
let inner = max.map(|i| Semaphore::new(i.into()));
MaybeSemaphore(Arc::new(inner))
}

async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
match self.0.as_ref() {
Some(inner) => {
// This can only ever error if the semaphore was closed,
// which we never do, so we can safely ignore any error case
MaybeSemaphorePermit(inner.acquire().await.ok())
}
None => MaybeSemaphorePermit(None),
}
}
}

#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
pub(crate) inner: reqwest::Client,
pub(crate) request_config: RequestConfig,
concurrent_request_semaphore: MaybeSemaphore,
next_request_id: Arc<AtomicU64>,
}

impl HttpClient {
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
HttpClient {
inner,
request_config,
concurrent_request_semaphore: MaybeSemaphore::new(
request_config.max_concurrent_requests,
),
next_request_id: AtomicU64::new(0).into(),
}
}

fn get_request_id(&self) -> String {
Expand Down Expand Up @@ -184,6 +218,9 @@ impl HttpClient {
request
};

// will be automatically dropped at the end of this function
let _handle = self.concurrent_request_semaphore.acquire().await;

debug!("Sending request");

// There's a bunch of state in send_request, factor out a pinned inner
Expand Down Expand Up @@ -259,3 +296,111 @@ impl tower::Service<http_old::Request<Bytes>> for HttpClient {
Box::pin(fut)
}
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use std::{
num::NonZeroUsize,
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
time::Duration,
};

use matrix_sdk_test::{async_test, test_json};
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};

use crate::{
http_client::RequestConfig,
test_utils::{set_client_session, test_client_builder_with_server},
};

#[async_test]
async fn ensure_concurrent_request_limit_is_observed() {
let (client_builder, server) = test_client_builder_with_server().await;
let client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
.build()
.await
.unwrap();

set_client_session(&client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.mount(&server)
.await;

Mock::given(method("GET"))
.and(path("_matrix/client/r0/account/whoami"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
// we stall the requests
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let bg_task = tokio::spawn(async move {
futures_util::future::join_all((0..10).map(|_| client.whoami())).await
});

// give it some time to issue the requests
tokio::time::sleep(Duration::from_millis(300)).await;

assert_eq!(
counter.load(Ordering::SeqCst),
5,
"More requests passed than the limit we configured"
);
bg_task.abort();
}

#[async_test]
async fn ensure_no_max_concurrent_request_does_not_limit() {
let (client_builder, server) = test_client_builder_with_server().await;
let client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(None))
.build()
.await
.unwrap();

set_client_session(&client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.mount(&server)
.await;

Mock::given(method("GET"))
.and(path("_matrix/client/r0/account/whoami"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let bg_task = tokio::spawn(async move {
futures_util::future::join_all((0..254).map(|_| client.whoami())).await
});

// give it some time to issue the requests
tokio::time::sleep(Duration::from_secs(1)).await;

assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
bg_task.abort();
}
}
Loading