diff --git a/Cargo.lock b/Cargo.lock index 1808274fc6..d21869027f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1664,6 +1664,7 @@ dependencies = [ "num_cpus", "once_cell", "packable", + "pin-project", "prefix-hex", "primitive-types", "rand", diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 791550ea69..d2d2d00937 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -86,6 +86,7 @@ rocksdb = { version = "0.21.0", default-features = false, features = [ rumqttc = { version = "0.22.0", default-features = false, features = [ "websocket", ], optional = true } +pin-project = { version = "1.1.3", default-features = false, optional = true } serde_repr = { version = "0.1.16", default-features = false, optional = true } thiserror = { version = "1.0.46", default-features = false, optional = true } time = { version = "0.3.25", default-features = false, features = [ @@ -206,6 +207,8 @@ client = [ "iota-crypto/keccak", "iota-crypto/bip44", "iota-crypto/random", + "dep:pin-project", + "rand" ] wallet = ["client"] diff --git a/sdk/src/client/api/block_builder/pow.rs b/sdk/src/client/api/block_builder/pow.rs index d8a11d3f41..c2e6e1a1f8 100644 --- a/sdk/src/client/api/block_builder/pow.rs +++ b/sdk/src/client/api/block_builder/pow.rs @@ -8,11 +8,11 @@ use crate::pow::miner::{Miner, MinerBuilder, MinerCancel}; #[cfg(target_family = "wasm")] use crate::pow::wasm_miner::{SingleThreadedMiner, SingleThreadedMinerBuilder}; use crate::{ - client::{ClientInner, Error, Result}, + client::{Client, Error, Result}, types::block::{parent::Parents, payload::Payload, Block, BlockBuilder, Error as BlockError}, }; -impl ClientInner { +impl Client { /// Finishes the block with local PoW if needed. /// Without local PoW, it will finish the block with a 0 nonce. pub async fn finish_block_builder(&self, parents: Option, payload: Option) -> Result { diff --git a/sdk/src/client/builder.rs b/sdk/src/client/builder.rs index 576042a576..a5ba229e8f 100644 --- a/sdk/src/client/builder.rs +++ b/sdk/src/client/builder.rs @@ -48,6 +48,9 @@ pub struct ClientBuilder { #[cfg(not(target_family = "wasm"))] #[serde(default, skip_serializing_if = "Option::is_none")] pub pow_worker_count: Option, + #[cfg(not(target_family = "wasm"))] + #[serde(default = "default_max_parallel_api_requests")] + pub max_parallel_api_requests: usize, } fn default_api_timeout() -> Duration { @@ -58,6 +61,11 @@ fn default_remote_pow_timeout() -> Duration { DEFAULT_REMOTE_POW_API_TIMEOUT } +#[cfg(not(target_family = "wasm"))] +fn default_max_parallel_api_requests() -> usize { + super::constants::MAX_PARALLEL_API_REQUESTS +} + impl Default for NetworkInfo { fn default() -> Self { Self { @@ -82,6 +90,8 @@ impl Default for ClientBuilder { remote_pow_timeout: DEFAULT_REMOTE_POW_API_TIMEOUT, #[cfg(not(target_family = "wasm"))] pow_worker_count: None, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests: super::constants::MAX_PARALLEL_API_REQUESTS, } } } @@ -269,6 +279,7 @@ impl ClientBuilder { sender: RwLock::new(mqtt_event_tx), receiver: RwLock::new(mqtt_event_rx), }, + worker_pool: crate::client::worker::WorkerPool::new(self.max_parallel_api_requests), }); client_inner.sync_nodes(&nodes, ignore_node_health).await?; @@ -327,6 +338,8 @@ impl ClientBuilder { remote_pow_timeout: client.get_remote_pow_timeout().await, #[cfg(not(target_family = "wasm"))] pow_worker_count: *client.pow_worker_count.read().await, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests: client.worker_pool.size().await, } } } diff --git a/sdk/src/client/core.rs b/sdk/src/client/core.rs index fb34c5764c..4caedb1598 100644 --- a/sdk/src/client/core.rs +++ b/sdk/src/client/core.rs @@ -13,6 +13,10 @@ use { tokio::sync::watch::{Receiver as WatchReceiver, Sender as WatchSender}, }; +#[cfg(not(target_family = "wasm"))] +pub use super::worker::TaskPriority; +#[cfg(not(target_family = "wasm"))] +use super::worker::WorkerPool; #[cfg(target_family = "wasm")] use crate::client::constants::CACHE_NETWORK_INFO_TIMEOUT_IN_SECONDS; use crate::{ @@ -56,6 +60,8 @@ pub struct ClientInner { pub(crate) mqtt: MqttInner, #[cfg(target_family = "wasm")] pub(crate) last_sync: tokio::sync::Mutex>, + #[cfg(not(target_family = "wasm"))] + pub(crate) worker_pool: WorkerPool, } #[derive(Default)] @@ -83,10 +89,13 @@ pub(crate) struct MqttInner { impl std::fmt::Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("Client"); - d.field("node_manager", &self.inner.node_manager); + d.field("node_manager", &self.node_manager); #[cfg(feature = "mqtt")] - d.field("broker_options", &self.inner.mqtt.broker_options); - d.field("network_info", &self.inner.network_info).finish() + d.field("broker_options", &self.mqtt.broker_options); + d.field("network_info", &self.network_info); + #[cfg(not(target_family = "wasm"))] + d.field("worker_pool", &self.worker_pool); + d.finish() } } @@ -95,9 +104,33 @@ impl Client { pub fn builder() -> ClientBuilder { ClientBuilder::new() } + + #[cfg(not(target_family = "wasm"))] + pub async fn rate_limit(&self, f: F) -> Fut::Output + where + F: 'static + Send + Sync + FnOnce(Self) -> Fut, + Fut: futures::Future + Send, + Fut::Output: Send, + { + self.prioritized_rate_limit(TaskPriority::Medium, f).await + } + + #[cfg(not(target_family = "wasm"))] + pub async fn prioritized_rate_limit(&self, priority: TaskPriority, f: F) -> Fut::Output + where + F: 'static + Send + Sync + FnOnce(Self) -> Fut, + Fut: futures::Future + Send, + Fut::Output: Send, + { + let client = self.clone(); + self.worker_pool + .process_task(priority, async move { f(client).await }) + .await + .unwrap() // TODO + } } -impl ClientInner { +impl Client { /// Gets the network related information such as network_id and min_pow_score /// and if it's the default one, sync it first and set the NetworkInfo. pub async fn get_network_info(&self) -> Result { diff --git a/sdk/src/client/mod.rs b/sdk/src/client/mod.rs index aabc2d440a..341b9a9d3a 100644 --- a/sdk/src/client/mod.rs +++ b/sdk/src/client/mod.rs @@ -48,6 +48,8 @@ pub mod storage; #[cfg_attr(docsrs, doc(cfg(feature = "stronghold")))] pub mod stronghold; pub mod utils; +#[cfg(not(target_family = "wasm"))] +pub(crate) mod worker; #[cfg(feature = "mqtt")] pub use self::node_api::mqtt; diff --git a/sdk/src/client/node_api/core/mod.rs b/sdk/src/client/node_api/core/mod.rs index 9ae39e13b5..0e0e57108a 100644 --- a/sdk/src/client/node_api/core/mod.rs +++ b/sdk/src/client/node_api/core/mod.rs @@ -5,8 +5,6 @@ pub mod routes; -#[cfg(not(target_family = "wasm"))] -use crate::client::constants::MAX_PARALLEL_API_REQUESTS; use crate::{ client::{Client, Result}, types::block::output::{OutputId, OutputMetadata, OutputWithMetadata}, @@ -15,87 +13,29 @@ use crate::{ impl Client { /// Request outputs by their output ID in parallel pub async fn get_outputs(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let outputs = futures::future::try_join_all(output_ids.iter().map(|id| self.get_output(id))).await?; - - #[cfg(not(target_family = "wasm"))] - let outputs = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - async move { - tokio::spawn(async move { - futures::future::try_join_all(output_ids_chunk.iter().map(|id| client.get_output(id))).await - }) - .await? - } - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(outputs) + futures::future::try_join_all(output_ids.iter().map(|id| self.get_output(id))).await } /// Request outputs by their output ID in parallel, ignoring failed requests /// Useful to get data about spent outputs, that might not be pruned yet pub async fn get_outputs_ignore_errors(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let outputs = futures::future::join_all(output_ids.iter().map(|id| self.get_output(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); - - #[cfg(not(target_family = "wasm"))] - let outputs = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - tokio::spawn(async move { - futures::future::join_all(output_ids_chunk.iter().map(|id| client.get_output(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect::>() - }) - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(outputs) + Ok( + futures::future::join_all(output_ids.iter().map(|id| self.get_output(id))) + .await + .into_iter() + .filter_map(Result::ok) + .collect(), + ) } /// Requests metadata for outputs by their output ID in parallel, ignoring failed requests pub async fn get_outputs_metadata_ignore_errors(&self, output_ids: &[OutputId]) -> Result> { - #[cfg(target_family = "wasm")] - let metadata = futures::future::join_all(output_ids.iter().map(|id| self.get_output_metadata(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); - - #[cfg(not(target_family = "wasm"))] - let metadata = - futures::future::try_join_all(output_ids.chunks(MAX_PARALLEL_API_REQUESTS).map(|output_ids_chunk| { - let client = self.clone(); - let output_ids_chunk = output_ids_chunk.to_vec(); - tokio::spawn(async move { - futures::future::join_all(output_ids_chunk.iter().map(|id| client.get_output_metadata(id))) - .await - .into_iter() - .filter_map(Result::ok) - .collect::>() - }) - })) - .await? - .into_iter() - .flatten() - .collect(); - - Ok(metadata) + Ok( + futures::future::join_all(output_ids.iter().map(|id| self.get_output_metadata(id))) + .await + .into_iter() + .filter_map(Result::ok) + .collect(), + ) } } diff --git a/sdk/src/client/node_api/core/routes.rs b/sdk/src/client/node_api/core/routes.rs index b0e2120507..cfdf43824c 100644 --- a/sdk/src/client/node_api/core/routes.rs +++ b/sdk/src/client/node_api/core/routes.rs @@ -11,7 +11,7 @@ use crate::{ client::{ constants::{DEFAULT_API_TIMEOUT, DEFAULT_USER_AGENT}, node_manager::node::{Node, NodeAuth}, - Client, ClientInner, Error, Result, + Client, Error, Result, }, types::{ api::core::response::{ @@ -43,7 +43,7 @@ pub struct NodeInfoWrapper { pub url: String, } -impl ClientInner { +impl Client { // Node routes. /// Returns the health of the node. @@ -76,20 +76,14 @@ impl ClientInner { pub async fn get_routes(&self) -> Result { let path = "api/routes"; - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) + self.get_request(path, None, self.get_timeout().await, false, false) .await } /// Returns general information about the node. /// GET /api/core/v2/info pub async fn get_info(&self) -> Result { - self.node_manager - .read() - .await - .get_request(INFO_PATH, None, self.get_timeout().await, false, false) + self.get_request(INFO_PATH, None, self.get_timeout().await, false, false) .await } @@ -101,9 +95,6 @@ impl ClientInner { let path = "api/core/v2/tips"; let response = self - .node_manager - .read() - .await .get_request::(path, None, self.get_timeout().await, false, false) .await?; @@ -126,9 +117,6 @@ impl ClientInner { // fallback to local PoW if remote PoW fails let response = match self - .node_manager - .read() - .await .post_request_json::(path, timeout, serde_json::to_value(block_dto)?, local_pow) .await { @@ -155,10 +143,7 @@ impl ClientInner { }; let block_dto = BlockDto::from(&block_with_local_pow); - self.node_manager - .read() - .await - .post_request_json(path, timeout, serde_json::to_value(block_dto)?, true) + self.post_request_json(path, timeout, serde_json::to_value(block_dto)?, true) .await? } Err(e) => return Err(e), @@ -180,9 +165,6 @@ impl ClientInner { // fallback to local Pow if remote Pow fails let response = match self - .node_manager - .read() - .await .post_request_bytes::(path, timeout, &block.pack_to_vec(), local_pow) .await { @@ -207,10 +189,7 @@ impl ClientInner { return Err(e); } }; - self.node_manager - .read() - .await - .post_request_bytes(path, timeout, &block_with_local_pow.pack_to_vec(), true) + self.post_request_bytes(path, timeout, &block_with_local_pow.pack_to_vec(), true) .await? } Err(e) => return Err(e), @@ -225,9 +204,6 @@ impl ClientInner { let path = &format!("api/core/v2/blocks/{block_id}"); let dto = self - .node_manager - .read() - .await .get_request::(path, None, self.get_timeout().await, false, true) .await?; @@ -242,11 +218,7 @@ impl ClientInner { pub async fn get_block_raw(&self, block_id: &BlockId) -> Result> { let path = &format!("api/core/v2/blocks/{block_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None, self.get_timeout().await).await } /// Returns the metadata of a block. @@ -254,11 +226,7 @@ impl ClientInner { pub async fn get_block_metadata(&self, block_id: &BlockId) -> Result { let path = &format!("api/core/v2/blocks/{block_id}/metadata"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, true, true) - .await + self.get_request(path, None, self.get_timeout().await, true, true).await } // UTXO routes. @@ -269,9 +237,6 @@ impl ClientInner { let path = &format!("api/core/v2/outputs/{output_id}"); let response: OutputWithMetadataResponse = self - .node_manager - .read() - .await .get_request(path, None, self.get_timeout().await, false, true) .await?; @@ -286,11 +251,7 @@ impl ClientInner { pub async fn get_output_raw(&self, output_id: &OutputId) -> Result> { let path = &format!("api/core/v2/outputs/{output_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None, self.get_timeout().await).await } /// Get the metadata for a given `OutputId` (TransactionId + output_index). @@ -298,10 +259,7 @@ impl ClientInner { pub async fn get_output_metadata(&self, output_id: &OutputId) -> Result { let path = &format!("api/core/v2/outputs/{output_id}/metadata"); - self.node_manager - .read() - .await - .get_request::(path, None, self.get_timeout().await, false, true) + self.get_request::(path, None, self.get_timeout().await, false, true) .await } @@ -311,9 +269,6 @@ impl ClientInner { let path = &"api/core/v2/receipts"; let resp = self - .node_manager - .read() - .await .get_request::(path, None, DEFAULT_API_TIMEOUT, false, false) .await?; @@ -326,9 +281,6 @@ impl ClientInner { let path = &format!("api/core/v2/receipts/{milestone_index}"); let resp = self - .node_manager - .read() - .await .get_request::(path, None, DEFAULT_API_TIMEOUT, false, false) .await?; @@ -341,11 +293,7 @@ impl ClientInner { pub async fn get_treasury(&self) -> Result { let path = "api/core/v2/treasury"; - self.node_manager - .read() - .await - .get_request(path, None, DEFAULT_API_TIMEOUT, false, false) - .await + self.get_request(path, None, DEFAULT_API_TIMEOUT, false, false).await } /// Returns the block, as object, that was included in the ledger for a given TransactionId. @@ -354,9 +302,6 @@ impl ClientInner { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block"); let dto = self - .node_manager - .read() - .await .get_request::(path, None, self.get_timeout().await, true, true) .await?; @@ -371,11 +316,7 @@ impl ClientInner { pub async fn get_included_block_raw(&self, transaction_id: &TransactionId) -> Result> { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None, self.get_timeout().await).await } /// Returns the metadata of the block that was included in the ledger for a given TransactionId. @@ -383,11 +324,7 @@ impl ClientInner { pub async fn get_included_block_metadata(&self, transaction_id: &TransactionId) -> Result { let path = &format!("api/core/v2/transactions/{transaction_id}/included-block/metadata"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, true, true) - .await + self.get_request(path, None, self.get_timeout().await, true, true).await } // Milestones routes. @@ -398,9 +335,6 @@ impl ClientInner { let path = &format!("api/core/v2/milestones/{milestone_id}"); let dto = self - .node_manager - .read() - .await .get_request::(path, None, self.get_timeout().await, false, true) .await?; @@ -415,11 +349,7 @@ impl ClientInner { pub async fn get_milestone_by_id_raw(&self, milestone_id: &MilestoneId) -> Result> { let path = &format!("api/core/v2/milestones/{milestone_id}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None, self.get_timeout().await).await } /// Gets all UTXO changes of a milestone by its milestone id. @@ -427,10 +357,7 @@ impl ClientInner { pub async fn get_utxo_changes_by_id(&self, milestone_id: &MilestoneId) -> Result { let path = &format!("api/core/v2/milestones/{milestone_id}/utxo-changes"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) + self.get_request(path, None, self.get_timeout().await, false, false) .await } @@ -440,9 +367,6 @@ impl ClientInner { let path = &format!("api/core/v2/milestones/by-index/{index}"); let dto = self - .node_manager - .read() - .await .get_request::(path, None, self.get_timeout().await, false, true) .await?; @@ -457,11 +381,7 @@ impl ClientInner { pub async fn get_milestone_by_index_raw(&self, index: u32) -> Result> { let path = &format!("api/core/v2/milestones/by-index/{index}"); - self.node_manager - .read() - .await - .get_request_bytes(path, None, self.get_timeout().await) - .await + self.get_request_bytes(path, None, self.get_timeout().await).await } /// Gets all UTXO changes of a milestone by its milestone index. @@ -469,10 +389,7 @@ impl ClientInner { pub async fn get_utxo_changes_by_index(&self, index: u32) -> Result { let path = &format!("api/core/v2/milestones/by-index/{index}/utxo-changes"); - self.node_manager - .read() - .await - .get_request(path, None, self.get_timeout().await, false, false) + self.get_request(path, None, self.get_timeout().await, false, false) .await } @@ -483,9 +400,6 @@ impl ClientInner { let path = "api/core/v2/peers"; let resp = self - .node_manager - .read() - .await .get_request::>(path, None, self.get_timeout().await, false, false) .await?; diff --git a/sdk/src/client/node_api/indexer/mod.rs b/sdk/src/client/node_api/indexer/mod.rs index 128e891f68..14a19b8bd4 100644 --- a/sdk/src/client/node_api/indexer/mod.rs +++ b/sdk/src/client/node_api/indexer/mod.rs @@ -8,11 +8,11 @@ pub mod routes; pub(crate) use self::query_parameters::{QueryParameter, QueryParameters}; use crate::{ - client::{ClientInner, Result}, + client::{Client, Result}, types::api::plugins::indexer::OutputIdsResponse, }; -impl ClientInner { +impl Client { /// Get all output ids for a provided URL route and query parameters. /// If a `QueryParameter::Cursor(_)` is provided, only a single page will be queried. pub async fn get_output_ids( @@ -33,9 +33,6 @@ impl ClientInner { while let Some(cursor) = { let output_ids_response = self - .node_manager - .read() - .await .get_request::( route, query_parameters.to_query_string().as_deref(), diff --git a/sdk/src/client/node_api/indexer/routes.rs b/sdk/src/client/node_api/indexer/routes.rs index 81090976aa..832146e3c9 100644 --- a/sdk/src/client/node_api/indexer/routes.rs +++ b/sdk/src/client/node_api/indexer/routes.rs @@ -12,7 +12,7 @@ use crate::{ }, QueryParameters, }, - ClientInner, Error, Result, + Client, Error, Result, }, types::{ api::plugins::indexer::OutputIdsResponse, @@ -22,7 +22,7 @@ use crate::{ // hornet: https://github.com/gohornet/hornet/blob/develop/plugins/indexer/routes.go -impl ClientInner { +impl Client { /// Get basic outputs filtered by the given parameters. /// GET with query parameter returns all outputIDs that fit these filter criteria. /// Query parameters: "address", "hasStorageDepositReturn", "storageDepositReturnAddress", diff --git a/sdk/src/client/node_api/participation.rs b/sdk/src/client/node_api/participation.rs index 4962d851ce..682d09a0ec 100644 --- a/sdk/src/client/node_api/participation.rs +++ b/sdk/src/client/node_api/participation.rs @@ -6,7 +6,7 @@ //! use crate::{ - client::{ClientInner, Result}, + client::{Client, Result}, types::{ api::plugins::participation::{ responses::{AddressOutputsResponse, EventsResponse, OutputStatusResponse}, @@ -19,7 +19,7 @@ use crate::{ }, }; -impl ClientInner { +impl Client { /// RouteParticipationEvents is the route to list all events, returning their ID, the event name and status. pub async fn events(&self, event_type: Option) -> Result { let route = "api/participation/v1/events"; @@ -29,10 +29,7 @@ impl ClientInner { ParticipationEventType::Staking => "type=1", }); - self.node_manager - .read() - .await - .get_request(route, query, self.get_timeout().await, false, false) + self.get_request(route, query, self.get_timeout().await, false, false) .await } @@ -40,10 +37,7 @@ impl ClientInner { pub async fn event(&self, event_id: &ParticipationEventId) -> Result { let route = format!("api/participation/v1/events/{event_id}"); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) + self.get_request(&route, None, self.get_timeout().await, false, false) .await } @@ -55,27 +49,21 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/events/{event_id}/status"); - self.node_manager - .read() - .await - .get_request( - &route, - milestone_index.map(|index| index.to_string()).as_deref(), - self.get_timeout().await, - false, - false, - ) - .await + self.get_request( + &route, + milestone_index.map(|index| index.to_string()).as_deref(), + self.get_timeout().await, + false, + false, + ) + .await } /// RouteOutputStatus is the route to get the vote status for a given output ID. pub async fn output_status(&self, output_id: &OutputId) -> Result { let route = format!("api/participation/v1/outputs/{output_id}"); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) + self.get_request(&route, None, self.get_timeout().await, false, false) .await } @@ -86,10 +74,7 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/addresses/{}", bech32_address.convert()?); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) + self.get_request(&route, None, self.get_timeout().await, false, false) .await } @@ -100,10 +85,7 @@ impl ClientInner { ) -> Result { let route = format!("api/participation/v1/addresses/{}/outputs", bech32_address.convert()?); - self.node_manager - .read() - .await - .get_request(&route, None, self.get_timeout().await, false, false) + self.get_request(&route, None, self.get_timeout().await, false, false) .await } } diff --git a/sdk/src/client/node_api/plugin/mod.rs b/sdk/src/client/node_api/plugin/mod.rs index aa7260a8f4..8e03879dec 100644 --- a/sdk/src/client/node_api/plugin/mod.rs +++ b/sdk/src/client/node_api/plugin/mod.rs @@ -7,9 +7,9 @@ use core::str::FromStr; use reqwest::Method; -use crate::client::{ClientInner, Result}; +use crate::client::{Client, Result}; -impl ClientInner { +impl Client { /// Extension method which provides request methods for plugins. pub async fn call_plugin_route( &self, @@ -20,22 +20,20 @@ impl ClientInner { request_object: Option, ) -> Result where - T: serde::de::DeserializeOwned + std::fmt::Debug + serde::Serialize, + T: serde::de::DeserializeOwned + std::fmt::Debug + serde::Serialize + Send, { let mut method = method.to_string(); method.make_ascii_uppercase(); let req_method = reqwest::Method::from_str(&method); - let node_manager = self.node_manager.read().await; let path = format!("{}{}{}", base_plugin_path, endpoint, query_params.join("&")); let timeout = self.get_timeout().await; match req_method { - Ok(Method::GET) => node_manager.get_request(&path, None, timeout, false, false).await, + Ok(Method::GET) => self.get_request(&path, None, timeout, false, false).await, Ok(Method::POST) => { - node_manager - .post_request_json(&path, timeout, request_object.into(), true) + self.post_request_json(&path, timeout, request_object.into(), true) .await } _ => Err(crate::client::Error::Node( diff --git a/sdk/src/client/node_manager/mod.rs b/sdk/src/client/node_manager/mod.rs index cfa8003269..67e9f7f9f9 100644 --- a/sdk/src/client/node_manager/mod.rs +++ b/sdk/src/client/node_manager/mod.rs @@ -11,13 +11,16 @@ pub(crate) mod syncing; use std::{ collections::{HashMap, HashSet}, + fmt::Debug, sync::RwLock, time::Duration, }; +use serde::{de::DeserializeOwned, Serialize}; use serde_json::Value; use self::{http_client::HttpClient, node::Node}; +use super::Client; use crate::{ client::{ error::{Error, Result}, @@ -42,7 +45,7 @@ pub struct NodeManager { pub(crate) http_client: HttpClient, } -impl std::fmt::Debug for NodeManager { +impl Debug for NodeManager { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("NodeManager"); d.field("primary_node", &self.primary_node); @@ -58,6 +61,123 @@ impl std::fmt::Debug for NodeManager { } } +impl Client { + pub(crate) async fn get_request( + &self, + path: &str, + query: Option<&str>, + timeout: Duration, + need_quorum: bool, + prefer_permanode: bool, + ) -> Result { + #[cfg(not(target_family = "wasm"))] + { + let path = path.to_owned(); + let query = query.map(ToOwned::to_owned); + self.rate_limit(move |client| async move { + client + .node_manager + .read() + .await + .get_request(&path, query.as_deref(), timeout, need_quorum, prefer_permanode) + .await + }) + .await + } + #[cfg(target_family = "wasm")] + self.node_manager + .read() + .await + .get_request(&path, query.as_deref(), timeout, need_quorum, prefer_permanode) + .await + } + + pub(crate) async fn get_request_bytes( + &self, + path: &str, + query: Option<&str>, + timeout: Duration, + ) -> Result> { + #[cfg(not(target_family = "wasm"))] + { + let path = path.to_owned(); + let query = query.map(ToOwned::to_owned); + self.rate_limit(move |client| async move { + client + .node_manager + .read() + .await + .get_request_bytes(&path, query.as_deref(), timeout) + .await + }) + .await + } + #[cfg(target_family = "wasm")] + self.node_manager + .read() + .await + .get_request_bytes(&path, query.as_deref(), timeout) + .await + } + + pub(crate) async fn post_request_bytes( + &self, + path: &str, + timeout: Duration, + body: &[u8], + local_pow: bool, + ) -> Result { + #[cfg(not(target_family = "wasm"))] + { + let path = path.to_owned(); + let body = body.to_owned(); + self.rate_limit(move |client| async move { + client + .node_manager + .read() + .await + .post_request_bytes(&path, timeout, &body, local_pow) + .await + }) + .await + } + #[cfg(target_family = "wasm")] + self.node_manager + .read() + .await + .post_request_bytes(&path, timeout, &body, local_pow) + .await + } + + pub(crate) async fn post_request_json( + &self, + path: &str, + timeout: Duration, + json: Value, + local_pow: bool, + ) -> Result { + #[cfg(not(target_family = "wasm"))] + { + let path = path.to_owned(); + self.rate_limit(move |client| async move { + client + .node_manager + .read() + .await + .post_request_json(&path, timeout, json, local_pow) + .await + }) + .await + } + #[cfg(target_family = "wasm")] + self.node_manager + .read() + .await + .post_request_json(&path, timeout, json, local_pow) + .await + } +} + impl NodeManager { pub(crate) fn builder() -> NodeManagerBuilder { NodeManagerBuilder::new() @@ -164,7 +284,7 @@ impl NodeManager { Ok(nodes_with_modified_url) } - pub(crate) async fn get_request( + pub(crate) async fn get_request( &self, path: &str, query: Option<&str>, @@ -312,7 +432,7 @@ impl NodeManager { Err(error.unwrap()) } - pub(crate) async fn post_request_bytes( + pub(crate) async fn post_request_bytes( &self, path: &str, timeout: Duration, @@ -341,7 +461,7 @@ impl NodeManager { Err(error.unwrap()) } - pub(crate) async fn post_request_json( + pub(crate) async fn post_request_json( &self, path: &str, timeout: Duration, diff --git a/sdk/src/client/utils.rs b/sdk/src/client/utils.rs index 2500b5e48f..59890a75a6 100644 --- a/sdk/src/client/utils.rs +++ b/sdk/src/client/utils.rs @@ -14,7 +14,7 @@ use crypto::{ use serde::{Deserialize, Serialize}; use zeroize::{Zeroize, ZeroizeOnDrop}; -use super::{Client, ClientInner}; +use super::Client; use crate::{ client::{Error, Result}, types::block::{ @@ -101,7 +101,7 @@ pub async fn request_funds_from_faucet(url: &str, bech32_address: &Bech32Address Ok(faucet_response) } -impl ClientInner { +impl Client { /// Transforms a hex encoded address to a bech32 encoded address pub async fn hex_to_bech32( &self, @@ -149,9 +149,7 @@ impl ClientInner { None => Ok(hex_public_key_to_bech32_address(hex, self.get_bech32_hrp().await?)?), } } -} -impl Client { /// Transforms bech32 to hex pub fn bech32_to_hex(bech32: impl ConvertTo) -> crate::client::Result { bech32_to_hex(bech32) diff --git a/sdk/src/client/worker.rs b/sdk/src/client/worker.rs new file mode 100644 index 0000000000..a9f691c3be --- /dev/null +++ b/sdk/src/client/worker.rs @@ -0,0 +1,239 @@ +// Copyright 2023 IOTA Stiftung +// SPDX-License-Identifier: Apache-2.0 + +use alloc::{ + collections::{BinaryHeap, VecDeque}, + sync::Arc, +}; +use core::pin::Pin; + +use futures::{future::BoxFuture, Future, FutureExt}; +use pin_project::pin_project; +use thiserror::Error; +use tokio::{ + sync::{mpsc::error::TryRecvError, oneshot::error::RecvError, Mutex}, + task::{JoinError, JoinHandle}, +}; + +#[derive(Debug, Error)] +pub(crate) enum WorkerError { + #[error("worker pool is empty")] + EmptyPool, + #[error("error sending worker task to processor")] + Send, + #[error("error receiving worker output: {0}")] + Receive(#[from] RecvError), + #[error("error exiting worker: {0}")] + Join(#[from] JoinError), +} + +#[derive(Debug)] +pub(crate) struct WorkerPool(Mutex>>); + +impl WorkerPool { + pub(crate) fn new(count: usize) -> Self { + let mut workers = VecDeque::with_capacity(count); + for _ in 0..count { + workers.push_back(Arc::new(Worker::spawn())); + } + Self(Mutex::new(workers)) + } + + pub(crate) async fn process_task( + &self, + priority: TaskPriority, + future: F, + ) -> Result + where + F::Output: Send, + { + let mut pool = self.0.lock().await; + let worker = pool.front().ok_or(WorkerError::EmptyPool)?.clone(); + // Move the worker to the back + pool.rotate_left(1); + drop(pool); + let output = worker.process_task(priority, future).await?; + Ok(output) + } + + pub(crate) async fn resize(&self, new_size: usize) -> Result<(), WorkerError> { + if new_size == 0 { + return Err(WorkerError::EmptyPool); + } + let mut pool = self.0.lock().await; + let curr_size = pool.len(); + match new_size.cmp(&curr_size) { + core::cmp::Ordering::Less => { + while pool.len() > new_size { + if let Some(worker) = pool.pop_front() { + worker.exit()?; + } + } + } + core::cmp::Ordering::Greater => { + while pool.len() < new_size { + pool.push_front(Arc::new(Worker::spawn())); + } + } + core::cmp::Ordering::Equal => (), + } + Ok(()) + } + + pub(crate) async fn size(&self) -> usize { + self.0.lock().await.len() + } +} + +pub(crate) enum WorkerEvent { + Task(WorkerTask), + Exit, +} + +#[repr(u8)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone)] +pub enum TaskPriority { + Low = 0, + Medium = 1, + High = 2, +} + +#[derive(Debug)] +pub(crate) struct Worker { + join_handle: JoinHandle<()>, + sender: tokio::sync::mpsc::UnboundedSender, +} + +impl Worker { + fn spawn() -> Self { + let (sender, mut recv) = tokio::sync::mpsc::unbounded_channel(); + let join_handle = tokio::spawn(async move { + let mut queue = BinaryHeap::new(); + let mut exiting = false; + // Wait to be awakened by the channel + while let Some(task) = recv.recv().await { + match task { + WorkerEvent::Task(task) => queue.push(task), + WorkerEvent::Exit => exiting = true, + } + loop { + if !exiting { + // Get up to 10 messages at a time + for _ in 0..10 { + match recv.try_recv() { + Ok(task) => match task { + WorkerEvent::Task(task) => queue.push(task), + WorkerEvent::Exit => exiting = true, + }, + Err(e) => match e { + TryRecvError::Empty => break, + TryRecvError::Disconnected => return, + }, + } + } + } + if let Some(next) = queue.pop() { + next.await; + } else { + break; + } + } + if exiting { + return; + } + } + }); + Self { join_handle, sender } + } + + async fn process_task( + &self, + priority: TaskPriority, + future: F, + ) -> Result + where + F::Output: Send, + { + let (task, recv) = WorkerTask::new(priority, future); + self.sender + .send(WorkerEvent::Task(task)) + .map_err(|_| WorkerError::Send)?; + Ok(recv.await?) + } + + fn exit(&self) -> Result<(), WorkerError> { + self.sender.send(WorkerEvent::Exit).map_err(|_| WorkerError::Send)?; + Ok(()) + } +} + +impl Drop for Worker { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +#[pin_project] +pub(crate) struct WorkerTask { + id: u128, + priority: TaskPriority, + #[pin] + future: BoxFuture<'static, ()>, +} + +impl WorkerTask { + fn new( + priority: TaskPriority, + future: F, + ) -> (Self, tokio::sync::oneshot::Receiver) + where + F::Output: Send, + { + let uuid = rand::random(); + let (sender, receiver) = tokio::sync::oneshot::channel(); + let future = async { + sender.send(future.await).ok(); + } + .boxed(); + ( + Self { + id: uuid, + priority, + future, + }, + receiver, + ) + } +} + +impl Future for WorkerTask { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll { + let this = self.project(); + this.future.poll(cx) + } +} + +impl core::fmt::Debug for WorkerTask { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("WorkerTask").field("priority", &self.priority).finish() + } +} + +impl Ord for WorkerTask { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.priority.cmp(&other.priority) + } +} +impl PartialOrd for WorkerTask { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl PartialEq for WorkerTask { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} +impl Eq for WorkerTask {} diff --git a/sdk/src/wallet/core/builder.rs b/sdk/src/wallet/core/builder.rs index 978ee198b2..31e2bba876 100644 --- a/sdk/src/wallet/core/builder.rs +++ b/sdk/src/wallet/core/builder.rs @@ -251,7 +251,7 @@ where #[cfg(feature = "storage")] pub(crate) async fn from_wallet(wallet: &Wallet) -> Self { Self { - client_options: Some(ClientOptions::from_client(wallet.client()).await), + client_options: Some(wallet.client_options().await), coin_type: Some(wallet.coin_type.load(Ordering::Relaxed)), storage_options: Some(wallet.storage_options.clone()), secret_manager: Some(wallet.secret_manager.clone()), diff --git a/sdk/src/wallet/core/operations/client.rs b/sdk/src/wallet/core/operations/client.rs index bf8c883adb..298e4f439a 100644 --- a/sdk/src/wallet/core/operations/client.rs +++ b/sdk/src/wallet/core/operations/client.rs @@ -42,6 +42,8 @@ where remote_pow_timeout, #[cfg(not(target_family = "wasm"))] pow_worker_count, + #[cfg(not(target_family = "wasm"))] + max_parallel_api_requests, } = client_options; self.client .update_node_manager(node_manager_builder.build(HashMap::new())) @@ -50,6 +52,12 @@ where *self.client.api_timeout.write().await = api_timeout; *self.client.remote_pow_timeout.write().await = remote_pow_timeout; #[cfg(not(target_family = "wasm"))] + self.client + .worker_pool + .resize(max_parallel_api_requests) + .await + .map_err(|e| crate::wallet::Error::Other(Box::new(e) as _))?; + #[cfg(not(target_family = "wasm"))] { *self.client.pow_worker_count.write().await = pow_worker_count; }