diff --git a/bindings/nodejs/CHANGELOG.md b/bindings/nodejs/CHANGELOG.md index 990d60264f..d783102ee0 100644 --- a/bindings/nodejs/CHANGELOG.md +++ b/bindings/nodejs/CHANGELOG.md @@ -19,7 +19,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security --> -## 1.0.9 - 2023-09-06 +## 1.0.9 - 2023-09-07 + +### Added + +- `IClientOptions::maxParallelApiRequests`; ### Fixed diff --git a/bindings/nodejs/lib/types/client/client-options.ts b/bindings/nodejs/lib/types/client/client-options.ts index 9036e532cd..cc5491b038 100644 --- a/bindings/nodejs/lib/types/client/client-options.ts +++ b/bindings/nodejs/lib/types/client/client-options.ts @@ -36,6 +36,8 @@ export interface IClientOptions { powWorkerCount?: number; /** Whether the PoW should be done locally or remotely. */ localPow?: boolean; + /** The maximum parallel API requests. */ + maxParallelApiRequests?: number; } /** Time duration */ diff --git a/bindings/nodejs/package.json b/bindings/nodejs/package.json index e70cb1eddf..17e1c54c2d 100644 --- a/bindings/nodejs/package.json +++ b/bindings/nodejs/package.json @@ -1,6 +1,6 @@ { "name": "@iota/sdk", - "version": "1.0.8", + "version": "1.0.9", "description": "Node.js binding to the IOTA SDK library", "main": "out/index.js", "types": "out/index.d.ts", diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index d121c23629..2f3c0ac01d 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -19,6 +19,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security --> +## 1.0.2 - 2023-MM-DD + +### Added + +- `ClientOptions::maxParallelApiRequests`; + ## 1.0.1 - 2023-08-23 ### Fixed diff --git a/bindings/python/iota_sdk/types/client_options.py b/bindings/python/iota_sdk/types/client_options.py index dcbe0d71b6..dd2631507b 100644 --- a/bindings/python/iota_sdk/types/client_options.py +++ b/bindings/python/iota_sdk/types/client_options.py @@ -84,6 +84,8 @@ class ClientOptions: Timeout when sending a block that requires remote proof of work. powWorkerCount (int): The amount of threads to be used for proof of work. + maxParallelApiRequests (int): + The maximum parallel API requests. """ primaryNode: Optional[str] = None primaryPowNode: Optional[str] = None @@ -103,6 +105,7 @@ class ClientOptions: apiTimeout: Optional[Duration] = None remotePowTimeout: Optional[Duration] = None powWorkerCount: Optional[int] = None + maxParallelApiRequests: Optional[int] = None def as_dict(self): config = {k: v for k, v in self.__dict__.items() if v is not None} diff --git a/sdk/CHANGELOG.md b/sdk/CHANGELOG.md index 6f6a2d0eff..bda9c242be 100644 --- a/sdk/CHANGELOG.md +++ b/sdk/CHANGELOG.md @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security --> -## 1.0.3 - 2023-MM-DD +## 1.0.3 - 2023-09-07 ### Added @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Wallet::get_chrysalis_data()` method; - `PrivateKeySecretManager` and `SecretManager::PrivateKey`; - `SecretManager::from` impl for variants; +- `Client` requests now obey a maximum concurrency using a request pool (set via `ClientBuilder::with_max_parallel_api_requests`); ### Fixed diff --git a/sdk/src/client/builder.rs b/sdk/src/client/builder.rs index 576042a576..8282525dbf 100644 --- a/sdk/src/client/builder.rs +++ b/sdk/src/client/builder.rs @@ -48,6 +48,10 @@ pub struct ClientBuilder { #[cfg(not(target_family = "wasm"))] #[serde(default, skip_serializing_if = "Option::is_none")] pub pow_worker_count: Option, + /// The maximum parallel API requests + #[cfg(not(target_family = "wasm"))] + #[serde(default = "default_max_parallel_api_requests")] + pub max_parallel_api_requests: usize, } fn default_api_timeout() -> Duration { @@ -58,6 +62,11 @@ fn default_remote_pow_timeout() -> Duration { DEFAULT_REMOTE_POW_API_TIMEOUT } +#[cfg(not(target_family = "wasm"))] +fn default_max_parallel_api_requests() -> usize { + super::constants::MAX_PARALLEL_API_REQUESTS +} + impl Default for NetworkInfo { fn default() -> Self { Self { @@ -82,6 +91,8 @@ impl Default for ClientBuilder { remote_pow_timeout: DEFAULT_REMOTE_POW_API_TIMEOUT, #[cfg(not(target_family = "wasm"))] pow_worker_count: None, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests: super::constants::MAX_PARALLEL_API_REQUESTS, } } } @@ -237,6 +248,13 @@ impl ClientBuilder { self } + /// Set maximum parallel API requests. + #[cfg(not(target_family = "wasm"))] + pub fn with_max_parallel_api_requests(mut self, max_parallel_api_requests: usize) -> Self { + self.max_parallel_api_requests = max_parallel_api_requests; + self + } + /// Build the Client instance. #[cfg(not(target_family = "wasm"))] pub async fn finish(self) -> Result { @@ -269,6 +287,7 @@ impl ClientBuilder { sender: RwLock::new(mqtt_event_tx), receiver: RwLock::new(mqtt_event_rx), }, + request_pool: crate::client::request_pool::RequestPool::new(self.max_parallel_api_requests), }); client_inner.sync_nodes(&nodes, ignore_node_health).await?; @@ -327,6 +346,8 @@ impl ClientBuilder { remote_pow_timeout: client.get_remote_pow_timeout().await, #[cfg(not(target_family = "wasm"))] pow_worker_count: *client.pow_worker_count.read().await, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests: client.request_pool.size().await, } } } diff --git a/sdk/src/client/core.rs b/sdk/src/client/core.rs index fb34c5764c..0b7249865b 100644 --- a/sdk/src/client/core.rs +++ b/sdk/src/client/core.rs @@ -13,6 +13,8 @@ use { tokio::sync::watch::{Receiver as WatchReceiver, Sender as WatchSender}, }; +#[cfg(not(target_family = "wasm"))] +use super::request_pool::RequestPool; #[cfg(target_family = "wasm")] use crate::client::constants::CACHE_NETWORK_INFO_TIMEOUT_IN_SECONDS; use crate::{ @@ -56,6 +58,8 @@ pub struct ClientInner { pub(crate) mqtt: MqttInner, #[cfg(target_family = "wasm")] pub(crate) last_sync: tokio::sync::Mutex>, + #[cfg(not(target_family = "wasm"))] + pub(crate) request_pool: RequestPool, } #[derive(Default)] @@ -83,10 +87,13 @@ pub(crate) struct MqttInner { impl std::fmt::Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("Client"); - d.field("node_manager", &self.inner.node_manager); + d.field("node_manager", &self.node_manager); #[cfg(feature = "mqtt")] - d.field("broker_options", &self.inner.mqtt.broker_options); - d.field("network_info", &self.inner.network_info).finish() + d.field("broker_options", &self.mqtt.broker_options); + d.field("network_info", &self.network_info); + #[cfg(not(target_family = "wasm"))] + d.field("request_pool", &self.request_pool); + d.finish() } } diff --git a/sdk/src/client/mod.rs b/sdk/src/client/mod.rs index aabc2d440a..7f6d79f8a3 100644 --- a/sdk/src/client/mod.rs +++ b/sdk/src/client/mod.rs @@ -42,6 +42,8 @@ pub mod core; pub mod error; pub mod node_api; pub mod node_manager; +#[cfg(not(target_family = "wasm"))] +pub(crate) mod request_pool; pub mod secret; pub mod storage; #[cfg(feature = "stronghold")] diff --git a/sdk/src/client/node_api/core/mod.rs b/sdk/src/client/node_api/core/mod.rs index 9ae39e13b5..0e0e57108a 100644 --- a/sdk/src/client/node_api/core/mod.rs +++ b/sdk/src/client/node_api/core/mod.rs @@ -5,8 +5,6 @@ pub mod routes; -#[cfg(not(target_family = "wasm"))] -use crate::client::constants::MAX_PARALLEL_API_REQUESTS; use crate::{ client::{Client, Result}, types::block::output::{OutputId, OutputMetadata, OutputWithMetadata}, @@ -15,87 +13,29 @@ use crate::{ impl Client { /// Request outputs by their output ID in parallel pub async fn get_outputs(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let outputs = futures::future::try_join_all(output_ids.iter().map(|id| self.get_output(id))).await?; - - #[cfg(not(target_family = "wasm"))] - let outputs = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - async move { - tokio::spawn(async move { - futures::future::try_join_all(output_ids_chunk.iter().map(|id| client.get_output(id))).await - }) - .await? - } - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(outputs) + futures::future::try_join_all(output_ids.iter().map(|id| self.get_output(id))).await } /// Request outputs by their output ID in parallel, ignoring failed requests /// Useful to get data about spent outputs, that might not be pruned yet pub async fn get_outputs_ignore_errors(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let outputs = futures::future::join_all(output_ids.iter().map(|id| self.get_output(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); - - #[cfg(not(target_family = "wasm"))] - let outputs = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - tokio::spawn(async move { - futures::future::join_all(output_ids_chunk.iter().map(|id| client.get_output(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect::>() - }) - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(outputs) + Ok( + futures::future::join_all(output_ids.iter().map(|id| self.get_output(id))) + .await + .into_iter() + .filter_map(Result::ok) + .collect(), + ) } /// Requests metadata for outputs by their output ID in parallel, ignoring failed requests pub async fn get_outputs_metadata_ignore_errors(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let metadata = futures::future::join_all(output_ids.iter().map(|id| self.get_output_metadata(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); - - #[cfg(not(target_family = "wasm"))] - let metadata = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - tokio::spawn(async move { - futures::future::join_all(output_ids_chunk.iter().map(|id| client.get_output_metadata(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect::>() - }) - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(metadata) + Ok( + futures::future::join_all(output_ids.iter().map(|id| self.get_output_metadata(id))) + .await + .into_iter() + .filter_map(Result::ok) + .collect(), + ) } } diff --git a/sdk/src/client/node_api/core/routes.rs b/sdk/src/client/node_api/core/routes.rs index b0e2120507..ec459f59d2 100644 --- a/sdk/src/client/node_api/core/routes.rs +++ b/sdk/src/client/node_api/core/routes.rs @@ -76,21 +76,13 @@ impl ClientInner { pub async fn get_routes(&self) -> Result { let path = "api/routes"; - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) - .await + self.get_request(path, None, false, false).await } /// Returns general information about the node. /// GET /api/core/v2/info pub async fn get_info(&self) -> Result { - self.node_manager - .read() - .await - .get_request(INFO_PATH, None, self.get_timeout().await, false, false) - .await + self.get_request(INFO_PATH, None, false, false).await } // Tangle routes. @@ -100,12 +92,7 @@ impl ClientInner { pub async fn get_tips(&self) -> Result> { let path = "api/core/v2/tips"; - let response = self - .node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, false) - .await?; + let response = self.get_request::(path, None, false, false).await?; Ok(response.tips) } @@ -224,12 +211,7 @@ impl ClientInner { pub async fn get_block(&self, block_id: &BlockId) -> Result { let path = &format!("api/core/v2/blocks/{block_id}"); - let dto = self - .node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, true) - .await?; + let dto = self.get_request::(path, None, false, true).await?; Ok(Block::try_from_dto_with_params( dto, @@ -242,11 +224,7 @@ impl ClientInner { pub async fn get_block_raw(&self, block_id: &BlockId) -> Result> { let path = &format!("api/core/v2/blocks/{block_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None).await } /// Returns the metadata of a block. @@ -254,11 +232,7 @@ impl ClientInner { pub async fn get_block_metadata(&self, block_id: &BlockId) -> Result { let path = &format!("api/core/v2/blocks/{block_id}/metadata"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, true, true) - .await + self.get_request(path, None, true, true).await } // UTXO routes. @@ -268,12 +242,7 @@ impl ClientInner { pub async fn get_output(&self, output_id: &OutputId) -> Result { let path = &format!("api/core/v2/outputs/{output_id}"); - let response: OutputWithMetadataResponse = self - .node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, true) - .await?; + let response: OutputWithMetadataResponse = self.get_request(path, None, false, true).await?; let token_supply = self.get_token_supply().await?; let output = Output::try_from_dto_with_params(response.output, token_supply)?; @@ -286,11 +255,7 @@ impl ClientInner { pub async fn get_output_raw(&self, output_id: &OutputId) -> Result> { let path = &format!("api/core/v2/outputs/{output_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None).await } /// Get the metadata for a given `OutputId` (TransactionId + output_index). @@ -298,11 +263,7 @@ impl ClientInner { pub async fn get_output_metadata(&self, output_id: &OutputId) -> Result { let path = &format!("api/core/v2/outputs/{output_id}/metadata"); - self.node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, true) - .await + self.get_request::(path, None, false, true).await } /// Gets all stored receipts. @@ -310,12 +271,7 @@ impl ClientInner { pub async fn get_receipts(&self) -> Result> { let path = &"api/core/v2/receipts"; - let resp = self - .node_manager - .read() - .await - .get_request::(path, None, DEFAULT_API_TIMEOUT, false, false) - .await?; + let resp = self.get_request::(path, None, false, false).await?; Ok(resp.receipts) } @@ -325,12 +281,7 @@ impl ClientInner { pub async fn get_receipts_migrated_at(&self, milestone_index: u32) -> Result> { let path = &format!("api/core/v2/receipts/{milestone_index}"); - let resp = self - .node_manager - .read() - .await - .get_request::(path, None, DEFAULT_API_TIMEOUT, false, false) - .await?; + let resp = self.get_request::(path, None, false, false).await?; Ok(resp.receipts) } @@ -341,11 +292,7 @@ impl ClientInner { pub async fn get_treasury(&self) -> Result { let path = "api/core/v2/treasury"; - self.node_manager - .read() - .await - .get_request(path, None, DEFAULT_API_TIMEOUT, false, false) - .await + self.get_request(path, None, false, false).await } /// Returns the block, as object, that was included in the ledger for a given TransactionId. @@ -353,12 +300,7 @@ impl ClientInner { pub async fn get_included_block(&self, transaction_id: &TransactionId) -> Result { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block"); - let dto = self - .node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, true, true) - .await?; + let dto = self.get_request::(path, None, true, true).await?; Ok(Block::try_from_dto_with_params( dto, @@ -371,11 +313,7 @@ impl ClientInner { pub async fn get_included_block_raw(&self, transaction_id: &TransactionId) -> Result> { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None).await } /// Returns the metadata of the block that was included in the ledger for a given TransactionId. @@ -383,11 +321,7 @@ impl ClientInner { pub async fn get_included_block_metadata(&self, transaction_id: &TransactionId) -> Result { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block/metadata"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, true, true) - .await + self.get_request(path, None, true, true).await } // Milestones routes. @@ -397,12 +331,7 @@ impl ClientInner { pub async fn get_milestone_by_id(&self, milestone_id: &MilestoneId) -> Result { let path = &format!("api/core/v2/milestones/{milestone_id}"); - let dto = self - .node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, true) - .await?; + let dto = self.get_request::(path, None, false, true).await?; Ok(MilestonePayload::try_from_dto_with_params( dto, @@ -415,11 +344,7 @@ impl ClientInner { pub async fn get_milestone_by_id_raw(&self, milestone_id: &MilestoneId) -> Result> { let path = &format!("api/core/v2/milestones/{milestone_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None).await } /// Gets all UTXO changes of a milestone by its milestone id. @@ -427,11 +352,7 @@ impl ClientInner { pub async fn get_utxo_changes_by_id(&self, milestone_id: &MilestoneId) -> Result { let path = &format!("api/core/v2/milestones/{milestone_id}/utxo-changes"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) - .await + self.get_request(path, None, false, false).await } /// Gets the milestone by the given milestone index. @@ -439,12 +360,7 @@ impl ClientInner { pub async fn get_milestone_by_index(&self, index: u32) -> Result { let path = &format!("api/core/v2/milestones/by-index/{index}"); - let dto = self - .node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, true) - .await?; + let dto = self.get_request::(path, None, false, true).await?; Ok(MilestonePayload::try_from_dto_with_params( dto, @@ -457,11 +373,7 @@ impl ClientInner { pub async fn get_milestone_by_index_raw(&self, index: u32) -> Result> { let path = &format!("api/core/v2/milestones/by-index/{index}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None).await } /// Gets all UTXO changes of a milestone by its milestone index. @@ -469,11 +381,7 @@ impl ClientInner { pub async fn get_utxo_changes_by_index(&self, index: u32) -> Result { let path = &format!("api/core/v2/milestones/by-index/{index}/utxo-changes"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) - .await + self.get_request(path, None, false, false).await } // Peers routes. @@ -482,12 +390,7 @@ impl ClientInner { pub async fn get_peers(&self) -> Result> { let path = "api/core/v2/peers"; - let resp = self - .node_manager - .read() - .await - .get_request::>(path, None, self.get_timeout().await, false, false) - .await?; + let resp = self.get_request::>(path, None, false, false).await?; Ok(resp) } diff --git a/sdk/src/client/node_api/indexer/mod.rs b/sdk/src/client/node_api/indexer/mod.rs index 128e891f68..6ff4876411 100644 --- a/sdk/src/client/node_api/indexer/mod.rs +++ b/sdk/src/client/node_api/indexer/mod.rs @@ -33,13 +33,9 @@ impl ClientInner { while let Some(cursor) = { let output_ids_response = self - .node_manager - .read() - .await .get_request::( route, query_parameters.to_query_string().as_deref(), - self.get_timeout().await, need_quorum, prefer_permanode, ) diff --git a/sdk/src/client/node_api/participation.rs b/sdk/src/client/node_api/participation.rs index 4962d851ce..928a63a56c 100644 --- a/sdk/src/client/node_api/participation.rs +++ b/sdk/src/client/node_api/participation.rs @@ -29,22 +29,14 @@ impl ClientInner { ParticipationEventType::Staking => "type=1", }); - self.node_manager - .read() - .await - .get_request(route, query, self.get_timeout().await, false, false) - .await + self.get_request(route, query, false, false).await } /// RouteParticipationEvent is the route to access a single participation by its ID. pub async fn event(&self, event_id: &ParticipationEventId) -> Result { let route = format!("api/participation/v1/events/{event_id}"); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) - .await + self.get_request(&route, None, false, false).await } /// RouteParticipationEventStatus is the route to access the status of a single participation by its ID. @@ -55,28 +47,20 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/events/{event_id}/status"); - self.node_manager - .read() - .await - .get_request( - &route, - milestone_index.map(|index| index.to_string()).as_deref(), - self.get_timeout().await, - false, - false, - ) - .await + self.get_request( + &route, + milestone_index.map(|index| index.to_string()).as_deref(), + false, + false, + ) + .await } /// RouteOutputStatus is the route to get the vote status for a given output ID. pub async fn output_status(&self, output_id: &OutputId) -> Result { let route = format!("api/participation/v1/outputs/{output_id}"); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) - .await + self.get_request(&route, None, false, false).await } /// RouteAddressBech32Status is the route to get the staking rewards for the given bech32 address. @@ -86,11 +70,7 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/addresses/{}", bech32_address.convert()?); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) - .await + self.get_request(&route, None, false, false).await } /// RouteAddressBech32Outputs is the route to get the outputs for the given bech32 address. @@ -100,10 +80,6 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/addresses/{}/outputs", bech32_address.convert()?); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) - .await + self.get_request(&route, None, false, false).await } } diff --git a/sdk/src/client/node_api/plugin/mod.rs b/sdk/src/client/node_api/plugin/mod.rs index aa7260a8f4..493932da79 100644 --- a/sdk/src/client/node_api/plugin/mod.rs +++ b/sdk/src/client/node_api/plugin/mod.rs @@ -27,17 +27,11 @@ impl ClientInner { let req_method = reqwest::Method::from_str(&method); - let node_manager = self.node_manager.read().await; let path = format!("{}{}{}", base_plugin_path, endpoint, query_params.join("&")); - let timeout = self.get_timeout().await; match req_method { - Ok(Method::GET) => node_manager.get_request(&path, None, timeout, false, false).await, - Ok(Method::POST) => { - node_manager - .post_request_json(&path, timeout, request_object.into(), true) - .await - } + Ok(Method::GET) => self.get_request(&path, None, false, false).await, + Ok(Method::POST) => self.post_request_json(&path, request_object.into(), true).await, _ => Err(crate::client::Error::Node( crate::client::node_api::error::Error::NotSupported(method.to_string()), )), diff --git a/sdk/src/client/node_manager/mod.rs b/sdk/src/client/node_manager/mod.rs index cfa8003269..f6468e2bfd 100644 --- a/sdk/src/client/node_manager/mod.rs +++ b/sdk/src/client/node_manager/mod.rs @@ -11,13 +11,18 @@ pub(crate) mod syncing; use std::{ collections::{HashMap, HashSet}, + fmt::Debug, sync::RwLock, time::Duration, }; +use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use self::{http_client::HttpClient, node::Node}; +use super::ClientInner; +#[cfg(not(target_family = "wasm"))] +use crate::client::request_pool::RateLimitExt; use crate::{ client::{ error::{Error, Result}, @@ -42,7 +47,7 @@ pub struct NodeManager { pub(crate) http_client: HttpClient, } -impl std::fmt::Debug for NodeManager { +impl Debug for NodeManager { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("NodeManager"); d.field("primary_node", &self.primary_node); @@ -58,6 +63,43 @@ impl std::fmt::Debug for NodeManager { } } +impl ClientInner { + pub(crate) async fn get_request( + &self, + path: &str, + query: Option<&str>, + need_quorum: bool, + prefer_permanode: bool, + ) -> Result { + let node_manager = self.node_manager.read().await; + let request = node_manager.get_request(path, query, self.get_timeout().await, need_quorum, prefer_permanode); + #[cfg(not(target_family = "wasm"))] + let request = request.rate_limit(&self.request_pool); + request.await + } + + pub(crate) async fn get_request_bytes(&self, path: &str, query: Option<&str>) -> Result> { + let node_manager = self.node_manager.read().await; + let request = node_manager.get_request_bytes(path, query, self.get_timeout().await); + #[cfg(not(target_family = "wasm"))] + let request = request.rate_limit(&self.request_pool); + request.await + } + + pub(crate) async fn post_request_json( + &self, + path: &str, + json: Value, + local_pow: bool, + ) -> Result { + let node_manager = self.node_manager.read().await; + let request = node_manager.post_request_json(path, self.get_timeout().await, json, local_pow); + #[cfg(not(target_family = "wasm"))] + let request = request.rate_limit(&self.request_pool); + request.await + } +} + impl NodeManager { pub(crate) fn builder() -> NodeManagerBuilder { NodeManagerBuilder::new() @@ -164,7 +206,7 @@ impl NodeManager { Ok(nodes_with_modified_url) } - pub(crate) async fn get_request( + pub(crate) async fn get_request( &self, path: &str, query: Option<&str>, @@ -312,7 +354,7 @@ impl NodeManager { Err(error.unwrap()) } - pub(crate) async fn post_request_bytes( + pub(crate) async fn post_request_bytes( &self, path: &str, timeout: Duration, @@ -341,7 +383,7 @@ impl NodeManager { Err(error.unwrap()) } - pub(crate) async fn post_request_json( + pub(crate) async fn post_request_json( &self, path: &str, timeout: Duration, diff --git a/sdk/src/client/request_pool.rs b/sdk/src/client/request_pool.rs new file mode 100644 index 0000000000..d040b2c664 --- /dev/null +++ b/sdk/src/client/request_pool.rs @@ -0,0 +1,93 @@ +// Copyright 2023 IOTA Stiftung +// SPDX-License-Identifier: Apache-2.0 + +use alloc::sync::Arc; + +use async_trait::async_trait; +use futures::Future; +use tokio::sync::{ + mpsc::{UnboundedReceiver, UnboundedSender}, + RwLock, +}; + +#[derive(Debug, Clone)] +pub(crate) struct RequestPool { + inner: Arc>, +} + +#[derive(Debug)] +pub(crate) struct RequestPoolInner { + sender: UnboundedSender<()>, + recv: UnboundedReceiver<()>, + size: usize, +} + +#[derive(Debug)] +pub(crate) struct Requester { + sender: UnboundedSender<()>, +} + +impl RequestPool { + pub(crate) fn new(size: usize) -> Self { + Self { + inner: Arc::new(RwLock::new(RequestPoolInner::new(size))), + } + } + + pub(crate) async fn borrow(&self) -> Requester { + // Get permission to request + let mut lock = self.write().await; + lock.recv.recv().await; + let sender = lock.sender.clone(); + drop(lock); + Requester { sender } + } + + pub(crate) async fn size(&self) -> usize { + self.read().await.size + } + + pub(crate) async fn resize(&self, new_size: usize) { + *self.write().await = RequestPoolInner::new(new_size); + } +} + +impl core::ops::Deref for RequestPool { + type Target = RwLock; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl RequestPoolInner { + fn new(size: usize) -> Self { + let (sender, recv) = tokio::sync::mpsc::unbounded_channel(); + // Prepare the channel with the requesters + for _ in 0..size { + sender.send(()).ok(); + } + Self { sender, recv, size } + } +} + +impl Drop for Requester { + fn drop(&mut self) { + // This can only fail if the receiver is closed, in which case we don't care. + self.sender.send(()).ok(); + } +} + +#[async_trait] +pub(crate) trait RateLimitExt: Future { + async fn rate_limit(self, request_pool: &RequestPool) -> Self::Output + where + Self: Sized, + { + let requester = request_pool.borrow().await; + let output = self.await; + drop(requester); + output + } +} +impl RateLimitExt for F {} diff --git a/sdk/src/wallet/core/builder.rs b/sdk/src/wallet/core/builder.rs index 978ee198b2..31e2bba876 100644 --- a/sdk/src/wallet/core/builder.rs +++ b/sdk/src/wallet/core/builder.rs @@ -251,7 +251,7 @@ where #[cfg(feature = "storage")] pub(crate) async fn from_wallet(wallet: &Wallet) -> Self { Self { - client_options: Some(ClientOptions::from_client(wallet.client()).await), + client_options: Some(wallet.client_options().await), coin_type: Some(wallet.coin_type.load(Ordering::Relaxed)), storage_options: Some(wallet.storage_options.clone()), secret_manager: Some(wallet.secret_manager.clone()), diff --git a/sdk/src/wallet/core/operations/client.rs b/sdk/src/wallet/core/operations/client.rs index bf8c883adb..a1a5f6d654 100644 --- a/sdk/src/wallet/core/operations/client.rs +++ b/sdk/src/wallet/core/operations/client.rs @@ -42,6 +42,8 @@ where remote_pow_timeout, #[cfg(not(target_family = "wasm"))] pow_worker_count, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests, } = client_options; self.client .update_node_manager(node_manager_builder.build(HashMap::new())) @@ -50,6 +52,8 @@ where *self.client.api_timeout.write().await = api_timeout; *self.client.remote_pow_timeout.write().await = remote_pow_timeout; #[cfg(not(target_family = "wasm"))] + self.client.request_pool.resize(max_parallel_api_requests).await; + #[cfg(not(target_family = "wasm"))] { *self.client.pow_worker_count.write().await = pow_worker_count; }