diff --git a/linera-client/src/client_context.rs b/linera-client/src/client_context.rs index b4cbcb47607..7c4014844cd 100644 --- a/linera-client/src/client_context.rs +++ b/linera-client/src/client_context.rs @@ -146,6 +146,7 @@ where storage, options.max_pending_messages, delivery, + wallet.chain_ids(), ); ClientContext { @@ -519,9 +520,15 @@ where .expect("failed to create new chain"); let chain_id = ChainId::child(message_id); key_pairs.insert(chain_id, key_pair.copy()); + self.client.track_chain(chain_id); self.update_wallet_for_new_chain(chain_id, Some(key_pair.copy()), timestamp); } } + let updated_chain_client = self.make_chain_client(default_chain_id); + updated_chain_client + .retry_pending_outgoing_messages() + .await + .context("outgoing messages to create the new chains should be delivered")?; for chain_id in key_pairs.keys() { let child_client = self.make_chain_client(*chain_id); diff --git a/linera-client/src/unit_tests/chain_listener.rs b/linera-client/src/unit_tests/chain_listener.rs index 26a5a6fbe11..ad5caf9ce06 100644 --- a/linera-client/src/unit_tests/chain_listener.rs +++ b/linera-client/src/unit_tests/chain_listener.rs @@ -150,6 +150,7 @@ async fn test_chain_listener() -> anyhow::Result<()> { storage.clone(), 10, delivery, + [chain_id0], )), }; let key_pair = KeyPair::generate_from(&mut rng); diff --git a/linera-core/src/chain_worker/actor.rs b/linera-core/src/chain_worker/actor.rs index 4ccee66330b..7fc9a584851 100644 --- a/linera-core/src/chain_worker/actor.rs +++ b/linera-core/src/chain_worker/actor.rs @@ -4,8 +4,9 @@ //! An actor that runs a chain worker. use std::{ + collections::HashSet, fmt::{self, Debug, Formatter}, - sync::Arc, + sync::{Arc, RwLock}, }; use linera_base::{ @@ -151,6 +152,7 @@ where storage: StorageClient, certificate_value_cache: Arc>, blob_cache: Arc>, + tracked_chains: Option>>>, chain_id: ChainId, ) -> Result { let (service_runtime_thread, execution_state_receiver, runtime_request_sender) = @@ -161,6 +163,7 @@ where storage, certificate_value_cache, blob_cache, + tracked_chains, chain_id, execution_state_receiver, runtime_request_sender, diff --git a/linera-core/src/chain_worker/state/attempted_changes.rs b/linera-core/src/chain_worker/state/attempted_changes.rs index 752d0b04c9e..5bd48571988 100644 --- a/linera-core/src/chain_worker/state/attempted_changes.rs +++ b/linera-core/src/chain_worker/state/attempted_changes.rs @@ -346,6 +346,7 @@ where tip.num_outgoing_messages += executed_block.outcome.messages.len() as u32; self.state.chain.confirmed_log.push(certificate.hash()); let info = ChainInfoResponse::new(&self.state.chain, self.state.config.key_pair()); + self.state.track_newly_created_chains(executed_block); let mut actions = self.state.create_network_actions().await?; actions.notifications.push(Notification { chain_id: block.chain_id, diff --git a/linera-core/src/chain_worker/state/mod.rs b/linera-core/src/chain_worker/state/mod.rs index 369a4a034ad..6cc8a1b8487 100644 --- a/linera-core/src/chain_worker/state/mod.rs +++ b/linera-core/src/chain_worker/state/mod.rs @@ -9,7 +9,7 @@ mod temporary_changes; use std::{ borrow::Cow, collections::{BTreeMap, BTreeSet, HashMap, HashSet}, - sync::Arc, + sync::{self, Arc}, }; use linera_base::{ @@ -26,8 +26,8 @@ use linera_chain::{ ChainError, ChainStateView, }; use linera_execution::{ - committee::Epoch, ExecutionRequest, Query, QueryContext, Response, ServiceRuntimeRequest, - UserApplicationDescription, UserApplicationId, + committee::Epoch, ExecutionRequest, Message, Query, QueryContext, Response, + ServiceRuntimeRequest, SystemMessage, UserApplicationDescription, UserApplicationId, }; use linera_storage::Storage; use linera_views::views::{ClonableView, ViewError}; @@ -60,6 +60,7 @@ where runtime_request_sender: std::sync::mpsc::Sender, recent_hashed_certificate_values: Arc>, recent_blobs: Arc>, + tracked_chains: Option>>>, knows_chain_is_active: bool, } @@ -69,11 +70,13 @@ where ViewError: From, { /// Creates a new [`ChainWorkerState`] using the provided `storage` client. + #[allow(clippy::too_many_arguments)] pub async fn load( config: ChainWorkerConfig, storage: StorageClient, certificate_value_cache: Arc>, blob_cache: Arc>, + tracked_chains: Option>>>, chain_id: ChainId, execution_state_receiver: futures::channel::mpsc::UnboundedReceiver, runtime_request_sender: std::sync::mpsc::Sender, @@ -89,6 +92,7 @@ where runtime_request_sender, recent_hashed_certificate_values: certificate_value_cache, recent_blobs: blob_cache, + tracked_chains, knows_chain_is_active: false, }) } @@ -369,11 +373,41 @@ where self.recent_blobs.insert(blob).await } + /// Adds any newly created chains to the set of `tracked_chains`. + fn track_newly_created_chains(&self, block: &ExecutedBlock) { + if let Some(tracked_chains) = self.tracked_chains.as_ref() { + let messages = block.messages().iter().flatten(); + let open_chain_message_indices = + messages + .enumerate() + .filter_map(|(index, outgoing_message)| match outgoing_message.message { + Message::System(SystemMessage::OpenChain(_)) => Some(index), + _ => None, + }); + let open_chain_message_ids = + open_chain_message_indices.map(|index| block.message_id(index as u32)); + let new_chain_ids = open_chain_message_ids.map(ChainId::child); + + tracked_chains + .write() + .expect("Panics should not happen while holding a lock to `tracked_chains`") + .extend(new_chain_ids); + } + } + /// Loads pending cross-chain requests. async fn create_network_actions(&self) -> Result { let mut heights_by_recipient: BTreeMap<_, BTreeMap<_, _>> = Default::default(); - let pairs = self.chain.outboxes.try_load_all_entries().await?; - for (target, outbox) in pairs { + let mut targets = self.chain.outboxes.indices().await?; + if let Some(tracked_chains) = self.tracked_chains.as_ref() { + let tracked_chains = tracked_chains + .read() + .expect("Panics should not happen while holding a lock to `tracked_chains`"); + targets.retain(|target| tracked_chains.contains(&target.recipient)); + } + let outboxes = self.chain.outboxes.try_load_entries(&targets).await?; + for (target, outbox) in targets.into_iter().zip(outboxes) { + let outbox = outbox.expect("Only existing outboxes should be referenced by `indices`"); let heights = outbox.queue.elements().await?; heights_by_recipient .entry(target.recipient) diff --git a/linera-core/src/client.rs b/linera-core/src/client.rs index ba912604210..e1a5b93a557 100644 --- a/linera-core/src/client.rs +++ b/linera-core/src/client.rs @@ -7,7 +7,7 @@ use std::{ convert::Infallible, iter, ops::{Deref, DerefMut}, - sync::Arc, + sync::{Arc, RwLock}, }; use dashmap::{ @@ -91,6 +91,9 @@ where message_policy: MessagePolicy, /// Whether to block on cross-chain message delivery. cross_chain_message_delivery: CrossChainMessageDelivery, + /// Chains that should be tracked by the client. + // TODO(#2412): Merge with set of chains the client is receiving notifications from validators + tracked_chains: Arc>>, /// References to clients waiting for chain notifications. notifier: Arc>, /// A copy of the storage client so that we don't have to lock the local node client @@ -111,10 +114,16 @@ where storage: S, max_pending_messages: usize, cross_chain_message_delivery: CrossChainMessageDelivery, + tracked_chains: impl IntoIterator, ) -> Self { - let state = WorkerState::new_for_client("Client node".to_string(), storage.clone()) - .with_allow_inactive_chains(true) - .with_allow_messages_from_deprecated_epochs(true); + let tracked_chains = Arc::new(RwLock::new(tracked_chains.into_iter().collect())); + let state = WorkerState::new_for_client( + "Client node".to_string(), + storage.clone(), + tracked_chains.clone(), + ) + .with_allow_inactive_chains(true) + .with_allow_messages_from_deprecated_epochs(true); let local_node = LocalNodeClient::new(state); Self { @@ -124,6 +133,7 @@ where max_pending_messages, message_policy: MessagePolicy::new(BlanketMessagePolicy::Accept, None), cross_chain_message_delivery, + tracked_chains, notifier: Arc::new(Notifier::default()), storage, } @@ -141,6 +151,15 @@ where &self.local_node } + #[tracing::instrument(level = "trace", skip(self))] + /// Adds a chain to the set of chains tracked by the local node. + pub fn track_chain(&self, chain_id: ChainId) { + self.tracked_chains + .write() + .expect("Panics should not happen while holding a lock to `tracked_chains`") + .insert(chain_id); + } + #[tracing::instrument(level = "trace", skip_all, fields(chain_id, next_block_height))] /// Creates a new `ChainClient`. #[allow(clippy::too_many_arguments)] @@ -2501,6 +2520,12 @@ where executed_block.message_id_for_operation(0, OPEN_CHAIN_MESSAGE_INDEX) }) .ok_or_else(|| ChainClientError::InternalError("Failed to create new chain"))?; + // Add the new chain to the list of tracked chains + self.client.track_chain(ChainId::child(message_id)); + self.client + .local_node + .retry_pending_cross_chain_requests(self.chain_id) + .await?; return Ok(ClientOutcome::Committed((message_id, certificate))); } } @@ -2791,6 +2816,16 @@ where .await } + #[tracing::instrument(level = "trace")] + /// Handles any cross-chain requests for any pending outgoing messages. + pub async fn retry_pending_outgoing_messages(&self) -> Result<(), ChainClientError> { + self.client + .local_node + .retry_pending_cross_chain_requests(self.chain_id) + .await?; + Ok(()) + } + #[tracing::instrument(level = "trace", skip(from, limit))] pub async fn read_hashed_certificate_values_downward( &self, diff --git a/linera-core/src/local_node.rs b/linera-core/src/local_node.rs index 06b3d99dc3f..1fe436ead78 100644 --- a/linera-core/src/local_node.rs +++ b/linera-core/src/local_node.rs @@ -4,7 +4,7 @@ use std::{ borrow::Cow, - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; @@ -570,4 +570,23 @@ where } } } + + /// Handles any pending local cross-chain requests. + #[tracing::instrument(level = "trace", skip(self))] + pub async fn retry_pending_cross_chain_requests( + &self, + sender_chain: ChainId, + ) -> Result<(), LocalNodeError> { + let (_response, actions) = self + .node + .state + .handle_chain_info_query(ChainInfoQuery::new(sender_chain)) + .await?; + let mut requests = VecDeque::from_iter(actions.cross_chain_requests); + while let Some(request) = requests.pop_front() { + let new_actions = self.node.state.handle_cross_chain_request(request).await?; + requests.extend(new_actions.cross_chain_requests); + } + Ok(()) + } } diff --git a/linera-core/src/unit_tests/test_utils.rs b/linera-core/src/unit_tests/test_utils.rs index 749aac8be30..95790e236cb 100644 --- a/linera-core/src/unit_tests/test_utils.rs +++ b/linera-core/src/unit_tests/test_utils.rs @@ -796,6 +796,7 @@ where storage, 10, CrossChainMessageDelivery::NonBlocking, + [chain_id], )); Ok(builder.create_chain_client( chain_id, diff --git a/linera-core/src/worker.rs b/linera-core/src/worker.rs index 5c33898059d..db9df5b1274 100644 --- a/linera-core/src/worker.rs +++ b/linera-core/src/worker.rs @@ -4,9 +4,9 @@ use std::{ borrow::Cow, - collections::{hash_map, BTreeMap, HashMap, VecDeque}, + collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque}, num::NonZeroUsize, - sync::{Arc, LazyLock, Mutex}, + sync::{Arc, LazyLock, Mutex, RwLock}, time::Duration, }; @@ -235,6 +235,8 @@ where recent_hashed_certificate_values: Arc>, /// Cached blobs by `BlobId`. recent_blobs: Arc>, + /// Chain IDs that should be tracked by a worker. + tracked_chains: Option>>>, /// One-shot channels to notify callers when messages of a particular chain have been /// delivered. delivery_notifiers: Arc>, @@ -264,6 +266,7 @@ where chain_worker_config: ChainWorkerConfig::default().with_key_pair(key_pair), recent_hashed_certificate_values: Arc::new(ValueCache::default()), recent_blobs: Arc::new(ValueCache::default()), + tracked_chains: None, delivery_notifiers: Arc::default(), chain_worker_tasks: Arc::default(), chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))), @@ -271,8 +274,22 @@ where } #[tracing::instrument(level = "trace", skip(nickname, storage))] - pub fn new_for_client(nickname: String, storage: StorageClient) -> Self { - Self::new(nickname, None, storage) + pub fn new_for_client( + nickname: String, + storage: StorageClient, + tracked_chains: Arc>>, + ) -> Self { + WorkerState { + nickname, + storage, + chain_worker_config: ChainWorkerConfig::default(), + recent_hashed_certificate_values: Arc::new(ValueCache::default()), + recent_blobs: Arc::new(ValueCache::default()), + tracked_chains: Some(tracked_chains), + delivery_notifiers: Arc::default(), + chain_worker_tasks: Arc::default(), + chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))), + } } #[tracing::instrument(level = "trace", skip(self, value))] @@ -288,6 +305,15 @@ where self } + /// Configures the subset of chains that this worker is tracking. + pub fn with_tracked_chains( + mut self, + tracked_chains: impl IntoIterator, + ) -> Self { + self.tracked_chains = Some(Arc::new(RwLock::new(tracked_chains.into_iter().collect()))); + self + } + /// Returns an instance with the specified grace period, in microseconds. /// /// Blocks with a timestamp this far in the future will still be accepted, but the validator @@ -665,6 +691,7 @@ where self.storage.clone(), self.recent_hashed_certificate_values.clone(), self.recent_blobs.clone(), + self.tracked_chains.clone(), chain_id, ) .await?; diff --git a/linera-service/src/linera/main.rs b/linera-service/src/linera/main.rs index aef755fe133..3c5e06ad50e 100644 --- a/linera-service/src/linera/main.rs +++ b/linera-service/src/linera/main.rs @@ -1100,6 +1100,7 @@ impl Job { ViewError: From, { let state = WorkerState::new("Local node".to_string(), None, storage) + .with_tracked_chains([message_id.chain_id, chain_id]) .with_allow_inactive_chains(true) .with_allow_messages_from_deprecated_epochs(true); let node_client = LocalNodeClient::new(state);