Allow to limit the number of concurrent requests made by the sdk (#3625)
Add a new `max_concurrent_requests` parameter to `RequestConfig` that limits the number of HTTP(S) requests the internal SDK client issues concurrently (if > 0). The default behavior is unchanged: there is no limit on the number of concurrent requests.

This is especially useful on resource-constrained platforms (e.g. mobile), and when your usage pattern can issue many requests at the same time (like downloading and caching all avatars at startup).
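
As an illustration, here is a minimal sketch of how a caller could opt into the limit. It assumes the usual `Client::builder()`/`homeserver_url()` builder API; the helper name `build_limited_client` and the limit of 4 are only illustrative, and the `request_config()` call mirrors the one used in the tests added in this commit.

```rust
use std::num::NonZeroUsize;

use matrix_sdk::{config::RequestConfig, Client};

// Sketch: cap the internal HTTP client at 4 concurrent requests.
// Any request beyond that waits until one of the in-flight requests finishes.
async fn build_limited_client(homeserver_url: &str) -> Client {
    Client::builder()
        .homeserver_url(homeserver_url)
        .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(4)))
        .build()
        .await
        .expect("failed to build the client")
}
```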

- [x] Public API changes documented in changelogs (optional)

Signed-off-by: Benjamin Kampmann <[email protected]>
gnunicorn authored Jul 4, 2024
1 parent aaccfdf commit d49cb54
Showing 3 changed files with 163 additions and 3 deletions.
1 change: 1 addition & 0 deletions crates/matrix-sdk/CHANGELOG.md
@@ -22,6 +22,7 @@ Breaking changes:

Additions:

- New `RequestConfig.max_concurrent_requests`, which allows limiting the maximum number of concurrent requests the internal HTTP client issues (all further requests have to wait until the number drops below that threshold again)
- Expose new method `Client::Oidc::login_with_qr_code()`.
([#3466](https://github.com/matrix-org/matrix-rust-sdk/pull/3466))
- Add the `ClientBuilder::add_root_certificates()` method which re-exposes the
18 changes: 16 additions & 2 deletions crates/matrix-sdk/src/config/request.rs
@@ -14,6 +14,7 @@

use std::{
fmt::{self, Debug},
num::NonZeroUsize,
time::Duration,
};

@@ -44,18 +45,21 @@ pub struct RequestConfig {
pub(crate) timeout: Duration,
pub(crate) retry_limit: Option<u64>,
pub(crate) retry_timeout: Option<Duration>,
pub(crate) max_concurrent_requests: Option<NonZeroUsize>,
pub(crate) force_auth: bool,
}

#[cfg(not(tarpaulin_include))]
impl Debug for RequestConfig {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
self;

let mut res = fmt.debug_struct("RequestConfig");
res.field("timeout", timeout)
.maybe_field("retry_limit", retry_limit)
.maybe_field("retry_timeout", retry_timeout);
.maybe_field("retry_timeout", retry_timeout)
.maybe_field("max_concurrent_requests", max_concurrent_requests);

if *force_auth {
res.field("force_auth", &true);
@@ -71,6 +75,7 @@ impl Default for RequestConfig {
timeout: DEFAULT_REQUEST_TIMEOUT,
retry_limit: Default::default(),
retry_timeout: Default::default(),
max_concurrent_requests: Default::default(),
force_auth: false,
}
}
@@ -106,6 +111,15 @@ impl RequestConfig {
self
}

/// The total limit of requests that are pending or running concurrently.
/// Any additional request beyond that number will wait until another
/// concurrent request has finished. Requests are queued fairly.
#[must_use]
pub fn max_concurrent_requests(mut self, limit: Option<NonZeroUsize>) -> Self {
self.max_concurrent_requests = limit;
self
}

/// Set the timeout duration for all HTTP requests.
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
147 changes: 146 additions & 1 deletion crates/matrix-sdk/src/http_client/mod.rs
@@ -15,6 +15,7 @@
use std::{
any::type_name,
fmt::Debug,
num::NonZeroUsize,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
@@ -30,6 +31,7 @@ use ruma::api::{
error::{FromHttpResponseError, IntoHttpError},
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
};
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{debug, field::debug, instrument, trace};

use crate::{config::RequestConfig, error::HttpError};
@@ -48,16 +50,48 @@ pub(crate) use native::HttpSettings;

pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

#[derive(Clone, Debug)]
struct MaybeSemaphore(Arc<Option<Semaphore>>);

#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);

impl MaybeSemaphore {
fn new(max: Option<NonZeroUsize>) -> Self {
let inner = max.map(|i| Semaphore::new(i.into()));
MaybeSemaphore(Arc::new(inner))
}

async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
match self.0.as_ref() {
Some(inner) => {
// This can only ever error if the semaphore was closed,
// which we never do, so we can safely ignore any error case
MaybeSemaphorePermit(inner.acquire().await.ok())
}
None => MaybeSemaphorePermit(None),
}
}
}

#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
pub(crate) inner: reqwest::Client,
pub(crate) request_config: RequestConfig,
concurrent_request_semaphore: MaybeSemaphore,
next_request_id: Arc<AtomicU64>,
}

impl HttpClient {
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
HttpClient {
inner,
request_config,
concurrent_request_semaphore: MaybeSemaphore::new(
request_config.max_concurrent_requests,
),
next_request_id: AtomicU64::new(0).into(),
}
}

fn get_request_id(&self) -> String {
@@ -184,6 +218,9 @@ impl HttpClient {
request
};

// will be automatically dropped at the end of this function
let _handle = self.concurrent_request_semaphore.acquire().await;

debug!("Sending request");

// There's a bunch of state in send_request, factor out a pinned inner
Expand Down Expand Up @@ -259,3 +296,111 @@ impl tower::Service<http_old::Request<Bytes>> for HttpClient {
Box::pin(fut)
}
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use std::{
num::NonZeroUsize,
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
time::Duration,
};

use matrix_sdk_test::{async_test, test_json};
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};

use crate::{
http_client::RequestConfig,
test_utils::{set_client_session, test_client_builder_with_server},
};

#[async_test]
async fn ensure_concurrent_request_limit_is_observed() {
let (client_builder, server) = test_client_builder_with_server().await;
let client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
.build()
.await
.unwrap();

set_client_session(&client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.mount(&server)
.await;

Mock::given(method("GET"))
.and(path("_matrix/client/r0/account/whoami"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
// we stall the requests
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let bg_task = tokio::spawn(async move {
futures_util::future::join_all((0..10).map(|_| client.whoami())).await
});

// give it some time to issue the requests
tokio::time::sleep(Duration::from_millis(300)).await;

assert_eq!(
counter.load(Ordering::SeqCst),
5,
"More requests passed than the limit we configured"
);
bg_task.abort();
}

#[async_test]
async fn ensure_no_max_concurrent_request_does_not_limit() {
let (client_builder, server) = test_client_builder_with_server().await;
let client = client_builder
.request_config(RequestConfig::default().max_concurrent_requests(None))
.build()
.await
.unwrap();

set_client_session(&client).await;

let counter = Arc::new(AtomicU8::new(0));
let inner_counter = counter.clone();

Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.mount(&server)
.await;

Mock::given(method("GET"))
.and(path("_matrix/client/r0/account/whoami"))
.respond_with(move |_req: &Request| {
inner_counter.fetch_add(1, Ordering::SeqCst);
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
})
.mount(&server)
.await;

let bg_task = tokio::spawn(async move {
futures_util::future::join_all((0..254).map(|_| client.whoami())).await
});

// give it some time to issue the requests
tokio::time::sleep(Duration::from_secs(1)).await;

assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
bg_task.abort();
}
}
