Skip to content

Commit

Permalink
Limit network actions to tracked chains (in main) (#2393)
Browse files Browse the repository at this point in the history
* Add a `tracked_chains` field to `ChainWorkerState`

Prepare to only create network actions for the set of tracked chains.

* Add a `tracked_chains` field to `WorkerState`

Share it with the created chain worker actors.

* Add a `tracked_chains` field to `Client` type

Keep track of the chains that the client is interested in.

* Select tracked chains when starting

Specify which chains should be tracked by a new `Client`.

* Forward tracked chains to `WorkerState`

Configure the worker based on the client's selection.

* Only create network actions for tracked chains

Avoid handling chains that aren't interesting to the client.

* Add `retry_pending_cross_chain_requests` helper

Allow resending messages intended for chains that weren't tracked when
the outgoing message was scheduled, but became tracked later.

* Add `Client::track_chain` method

Allow adding more chains to the initial set of tracked chains.

* Track newly created chains

Ensure that chains that the client opens are tracked.

* Ensure newly assigned chain is tracked

So that the worker can properly handle it.

* Track chains used in benchmark

Ensure that they are properly executed during the benchmark.

* Add a TODO to merge tracked chains set

Remember to replace the quick-fix with a more comprehensive refactor.

* Track chains created during block execution

Check all executed blocks for messages that open new chains, and add the
new chain IDs to the set of tracked chains.
  • Loading branch information
jvff authored Aug 28, 2024
1 parent a07a3af commit 3bc1928
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 15 deletions.
7 changes: 7 additions & 0 deletions linera-client/src/client_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ where
storage,
options.max_pending_messages,
delivery,
wallet.chain_ids(),
);

ClientContext {
Expand Down Expand Up @@ -519,9 +520,15 @@ where
.expect("failed to create new chain");
let chain_id = ChainId::child(message_id);
key_pairs.insert(chain_id, key_pair.copy());
self.client.track_chain(chain_id);
self.update_wallet_for_new_chain(chain_id, Some(key_pair.copy()), timestamp);
}
}
let updated_chain_client = self.make_chain_client(default_chain_id);
updated_chain_client
.retry_pending_outgoing_messages()
.await
.context("outgoing messages to create the new chains should be delivered")?;

for chain_id in key_pairs.keys() {
let child_client = self.make_chain_client(*chain_id);
Expand Down
1 change: 1 addition & 0 deletions linera-client/src/unit_tests/chain_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ async fn test_chain_listener() -> anyhow::Result<()> {
storage.clone(),
10,
delivery,
[chain_id0],
)),
};
let key_pair = KeyPair::generate_from(&mut rng);
Expand Down
5 changes: 4 additions & 1 deletion linera-core/src/chain_worker/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
//! An actor that runs a chain worker.
use std::{
collections::HashSet,
fmt::{self, Debug, Formatter},
sync::Arc,
sync::{Arc, RwLock},
};

use linera_base::{
Expand Down Expand Up @@ -151,6 +152,7 @@ where
storage: StorageClient,
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
blob_cache: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
chain_id: ChainId,
) -> Result<Self, WorkerError> {
let (service_runtime_thread, execution_state_receiver, runtime_request_sender) =
Expand All @@ -161,6 +163,7 @@ where
storage,
certificate_value_cache,
blob_cache,
tracked_chains,
chain_id,
execution_state_receiver,
runtime_request_sender,
Expand Down
1 change: 1 addition & 0 deletions linera-core/src/chain_worker/state/attempted_changes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ where
tip.num_outgoing_messages += executed_block.outcome.messages.len() as u32;
self.state.chain.confirmed_log.push(certificate.hash());
let info = ChainInfoResponse::new(&self.state.chain, self.state.config.key_pair());
self.state.track_newly_created_chains(executed_block);
let mut actions = self.state.create_network_actions().await?;
actions.notifications.push(Notification {
chain_id: block.chain_id,
Expand Down
44 changes: 39 additions & 5 deletions linera-core/src/chain_worker/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod temporary_changes;
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
sync::Arc,
sync::{self, Arc},
};

use linera_base::{
Expand All @@ -26,8 +26,8 @@ use linera_chain::{
ChainError, ChainStateView,
};
use linera_execution::{
committee::Epoch, ExecutionRequest, Query, QueryContext, Response, ServiceRuntimeRequest,
UserApplicationDescription, UserApplicationId,
committee::Epoch, ExecutionRequest, Message, Query, QueryContext, Response,
ServiceRuntimeRequest, SystemMessage, UserApplicationDescription, UserApplicationId,
};
use linera_storage::Storage;
use linera_views::views::{ClonableView, ViewError};
Expand Down Expand Up @@ -60,6 +60,7 @@ where
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
knows_chain_is_active: bool,
}

Expand All @@ -69,11 +70,13 @@ where
ViewError: From<StorageClient::StoreError>,
{
/// Creates a new [`ChainWorkerState`] using the provided `storage` client.
#[allow(clippy::too_many_arguments)]
pub async fn load(
config: ChainWorkerConfig,
storage: StorageClient,
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
blob_cache: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
chain_id: ChainId,
execution_state_receiver: futures::channel::mpsc::UnboundedReceiver<ExecutionRequest>,
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
Expand All @@ -89,6 +92,7 @@ where
runtime_request_sender,
recent_hashed_certificate_values: certificate_value_cache,
recent_blobs: blob_cache,
tracked_chains,
knows_chain_is_active: false,
})
}
Expand Down Expand Up @@ -369,11 +373,41 @@ where
self.recent_blobs.insert(blob).await
}

/// Registers every chain opened by `block` in the shared set of `tracked_chains`.
///
/// Walks the flattened outgoing messages of the block, keeps the positions of the
/// `SystemMessage::OpenChain` messages, turns each position into the block's message ID and
/// derives the child chain ID from it. Does nothing when this worker has no
/// `tracked_chains` set configured.
fn track_newly_created_chains(&self, block: &ExecutedBlock) {
    let Some(tracked_chains) = self.tracked_chains.as_ref() else {
        return;
    };

    // The iterator is lazy: it is only driven by `extend` below, i.e. while the write
    // lock is held.
    let new_chain_ids = block
        .messages()
        .iter()
        .flatten()
        .enumerate()
        .filter(|(_, outgoing_message)| {
            matches!(
                outgoing_message.message,
                Message::System(SystemMessage::OpenChain(_))
            )
        })
        .map(|(index, _)| ChainId::child(block.message_id(index as u32)));

    let mut tracked = tracked_chains
        .write()
        .expect("Panics should not happen while holding a lock to `tracked_chains`");
    tracked.extend(new_chain_ids);
}

/// Loads pending cross-chain requests.
async fn create_network_actions(&self) -> Result<NetworkActions, WorkerError> {
let mut heights_by_recipient: BTreeMap<_, BTreeMap<_, _>> = Default::default();
let pairs = self.chain.outboxes.try_load_all_entries().await?;
for (target, outbox) in pairs {
let mut targets = self.chain.outboxes.indices().await?;
if let Some(tracked_chains) = self.tracked_chains.as_ref() {
let tracked_chains = tracked_chains
.read()
.expect("Panics should not happen while holding a lock to `tracked_chains`");
targets.retain(|target| tracked_chains.contains(&target.recipient));
}
let outboxes = self.chain.outboxes.try_load_entries(&targets).await?;
for (target, outbox) in targets.into_iter().zip(outboxes) {
let outbox = outbox.expect("Only existing outboxes should be referenced by `indices`");
let heights = outbox.queue.elements().await?;
heights_by_recipient
.entry(target.recipient)
Expand Down
43 changes: 39 additions & 4 deletions linera-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
convert::Infallible,
iter,
ops::{Deref, DerefMut},
sync::Arc,
sync::{Arc, RwLock},
};

use dashmap::{
Expand Down Expand Up @@ -91,6 +91,9 @@ where
message_policy: MessagePolicy,
/// Whether to block on cross-chain message delivery.
cross_chain_message_delivery: CrossChainMessageDelivery,
/// Chains that should be tracked by the client.
// TODO(#2412): Merge with set of chains the client is receiving notifications from validators
tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
/// References to clients waiting for chain notifications.
notifier: Arc<Notifier<Notification>>,
/// A copy of the storage client so that we don't have to lock the local node client
Expand All @@ -111,10 +114,16 @@ where
storage: S,
max_pending_messages: usize,
cross_chain_message_delivery: CrossChainMessageDelivery,
tracked_chains: impl IntoIterator<Item = ChainId>,
) -> Self {
let state = WorkerState::new_for_client("Client node".to_string(), storage.clone())
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let tracked_chains = Arc::new(RwLock::new(tracked_chains.into_iter().collect()));
let state = WorkerState::new_for_client(
"Client node".to_string(),
storage.clone(),
tracked_chains.clone(),
)
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let local_node = LocalNodeClient::new(state);

Self {
Expand All @@ -124,6 +133,7 @@ where
max_pending_messages,
message_policy: MessagePolicy::new(BlanketMessagePolicy::Accept, None),
cross_chain_message_delivery,
tracked_chains,
notifier: Arc::new(Notifier::default()),
storage,
}
Expand All @@ -141,6 +151,15 @@ where
&self.local_node
}

/// Adds a chain to the set of chains tracked by the local node.
///
/// The set is shared behind an `RwLock`; this takes the write lock for the duration of
/// the insertion and panics if the lock has been poisoned.
#[tracing::instrument(level = "trace", skip(self))]
pub fn track_chain(&self, chain_id: ChainId) {
    let mut tracked = self
        .tracked_chains
        .write()
        .expect("Panics should not happen while holding a lock to `tracked_chains`");
    tracked.insert(chain_id);
}

#[tracing::instrument(level = "trace", skip_all, fields(chain_id, next_block_height))]
/// Creates a new `ChainClient`.
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -2501,6 +2520,12 @@ where
executed_block.message_id_for_operation(0, OPEN_CHAIN_MESSAGE_INDEX)
})
.ok_or_else(|| ChainClientError::InternalError("Failed to create new chain"))?;
// Add the new chain to the list of tracked chains
self.client.track_chain(ChainId::child(message_id));
self.client
.local_node
.retry_pending_cross_chain_requests(self.chain_id)
.await?;
return Ok(ClientOutcome::Committed((message_id, certificate)));
}
}
Expand Down Expand Up @@ -2791,6 +2816,16 @@ where
.await
}

#[tracing::instrument(level = "trace")]
/// Handles any cross-chain requests for any pending outgoing messages.
pub async fn retry_pending_outgoing_messages(&self) -> Result<(), ChainClientError> {
self.client
.local_node
.retry_pending_cross_chain_requests(self.chain_id)
.await?;
Ok(())
}

#[tracing::instrument(level = "trace", skip(from, limit))]
pub async fn read_hashed_certificate_values_downward(
&self,
Expand Down
21 changes: 20 additions & 1 deletion linera-core/src/local_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use std::{
borrow::Cow,
collections::{HashMap, HashSet},
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
};

Expand Down Expand Up @@ -570,4 +570,23 @@ where
}
}
}

/// Handles any pending local cross-chain requests.
///
/// Issues a chain info query for `sender_chain`, takes the cross-chain requests returned
/// in the resulting network actions and processes them one by one. Requests produced
/// while handling a request are appended to the same queue, so the loop runs until no
/// request is left.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn retry_pending_cross_chain_requests(
    &self,
    sender_chain: ChainId,
) -> Result<(), LocalNodeError> {
    let worker = &self.node.state;
    let query = ChainInfoQuery::new(sender_chain);
    let (_response, initial_actions) = worker.handle_chain_info_query(query).await?;

    // FIFO queue: follow-up requests are handled after the ones already pending.
    let mut pending: VecDeque<_> = initial_actions.cross_chain_requests.into_iter().collect();
    while let Some(request) = pending.pop_front() {
        let follow_up = worker.handle_cross_chain_request(request).await?;
        pending.extend(follow_up.cross_chain_requests);
    }
    Ok(())
}
}
1 change: 1 addition & 0 deletions linera-core/src/unit_tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ where
storage,
10,
CrossChainMessageDelivery::NonBlocking,
[chain_id],
));
Ok(builder.create_chain_client(
chain_id,
Expand Down
35 changes: 31 additions & 4 deletions linera-core/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

use std::{
borrow::Cow,
collections::{hash_map, BTreeMap, HashMap, VecDeque},
collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque},
num::NonZeroUsize,
sync::{Arc, LazyLock, Mutex},
sync::{Arc, LazyLock, Mutex, RwLock},
time::Duration,
};

Expand Down Expand Up @@ -235,6 +235,8 @@ where
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
/// Cached blobs by `BlobId`.
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
/// Chain IDs that should be tracked by a worker.
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
/// One-shot channels to notify callers when messages of a particular chain have been
/// delivered.
delivery_notifiers: Arc<Mutex<DeliveryNotifiers>>,
Expand Down Expand Up @@ -264,15 +266,30 @@ where
chain_worker_config: ChainWorkerConfig::default().with_key_pair(key_pair),
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
recent_blobs: Arc::new(ValueCache::default()),
tracked_chains: None,
delivery_notifiers: Arc::default(),
chain_worker_tasks: Arc::default(),
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
}
}

#[tracing::instrument(level = "trace", skip(nickname, storage))]
pub fn new_for_client(nickname: String, storage: StorageClient) -> Self {
Self::new(nickname, None, storage)
pub fn new_for_client(
nickname: String,
storage: StorageClient,
tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
) -> Self {
WorkerState {
nickname,
storage,
chain_worker_config: ChainWorkerConfig::default(),
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
recent_blobs: Arc::new(ValueCache::default()),
tracked_chains: Some(tracked_chains),
delivery_notifiers: Arc::default(),
chain_worker_tasks: Arc::default(),
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
}
}

#[tracing::instrument(level = "trace", skip(self, value))]
Expand All @@ -288,6 +305,15 @@ where
self
}

/// Configures the subset of chains that this worker is tracking.
pub fn with_tracked_chains(
mut self,
tracked_chains: impl IntoIterator<Item = ChainId>,
) -> Self {
self.tracked_chains = Some(Arc::new(RwLock::new(tracked_chains.into_iter().collect())));
self
}

/// Returns an instance with the specified grace period, in microseconds.
///
/// Blocks with a timestamp this far in the future will still be accepted, but the validator
Expand Down Expand Up @@ -665,6 +691,7 @@ where
self.storage.clone(),
self.recent_hashed_certificate_values.clone(),
self.recent_blobs.clone(),
self.tracked_chains.clone(),
chain_id,
)
.await?;
Expand Down
1 change: 1 addition & 0 deletions linera-service/src/linera/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,7 @@ impl Job {
ViewError: From<S::StoreError>,
{
let state = WorkerState::new("Local node".to_string(), None, storage)
.with_tracked_chains([message_id.chain_id, chain_id])
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let node_client = LocalNodeClient::new(state);
Expand Down

0 comments on commit 3bc1928

Please sign in to comment.