Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit network actions to tracked chains (in main) #2393

Merged
merged 13 commits into from
Aug 28, 2024
7 changes: 7 additions & 0 deletions linera-client/src/client_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ where
storage,
options.max_pending_messages,
delivery,
wallet.chain_ids(),
);

ClientContext {
Expand Down Expand Up @@ -519,9 +520,15 @@ where
.expect("failed to create new chain");
let chain_id = ChainId::child(message_id);
key_pairs.insert(chain_id, key_pair.copy());
self.client.track_chain(chain_id);
self.update_wallet_for_new_chain(chain_id, Some(key_pair.copy()), timestamp);
}
}
let updated_chain_client = self.make_chain_client(default_chain_id);
updated_chain_client
.retry_pending_outgoing_messages()
.await
.context("outgoing messages to create the new chains should be delivered")?;

for chain_id in key_pairs.keys() {
let child_client = self.make_chain_client(*chain_id);
Expand Down
1 change: 1 addition & 0 deletions linera-client/src/unit_tests/chain_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ async fn test_chain_listener() -> anyhow::Result<()> {
storage.clone(),
10,
delivery,
[chain_id0],
)),
};
let key_pair = KeyPair::generate_from(&mut rng);
Expand Down
5 changes: 4 additions & 1 deletion linera-core/src/chain_worker/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
//! An actor that runs a chain worker.

use std::{
collections::HashSet,
fmt::{self, Debug, Formatter},
sync::Arc,
sync::{Arc, RwLock},
};

use linera_base::{
Expand Down Expand Up @@ -151,6 +152,7 @@ where
storage: StorageClient,
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
blob_cache: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
chain_id: ChainId,
) -> Result<Self, WorkerError> {
let (service_runtime_thread, execution_state_receiver, runtime_request_sender) =
Expand All @@ -161,6 +163,7 @@ where
storage,
certificate_value_cache,
blob_cache,
tracked_chains,
chain_id,
execution_state_receiver,
runtime_request_sender,
Expand Down
1 change: 1 addition & 0 deletions linera-core/src/chain_worker/state/attempted_changes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ where
tip.num_outgoing_messages += executed_block.outcome.messages.len() as u32;
self.state.chain.confirmed_log.push(certificate.hash());
let info = ChainInfoResponse::new(&self.state.chain, self.state.config.key_pair());
self.state.track_newly_created_chains(executed_block);
let mut actions = self.state.create_network_actions().await?;
actions.notifications.push(Notification {
chain_id: block.chain_id,
Expand Down
44 changes: 39 additions & 5 deletions linera-core/src/chain_worker/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod temporary_changes;
use std::{
borrow::Cow,
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
sync::Arc,
sync::{self, Arc},
};

use linera_base::{
Expand All @@ -26,8 +26,8 @@ use linera_chain::{
ChainError, ChainStateView,
};
use linera_execution::{
committee::Epoch, ExecutionRequest, Query, QueryContext, Response, ServiceRuntimeRequest,
UserApplicationDescription, UserApplicationId,
committee::Epoch, ExecutionRequest, Message, Query, QueryContext, Response,
ServiceRuntimeRequest, SystemMessage, UserApplicationDescription, UserApplicationId,
};
use linera_storage::Storage;
use linera_views::views::{ClonableView, ViewError};
Expand Down Expand Up @@ -60,6 +60,7 @@ where
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
knows_chain_is_active: bool,
}

Expand All @@ -69,11 +70,13 @@ where
ViewError: From<StorageClient::StoreError>,
{
/// Creates a new [`ChainWorkerState`] using the provided `storage` client.
#[allow(clippy::too_many_arguments)]
pub async fn load(
config: ChainWorkerConfig,
storage: StorageClient,
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
blob_cache: Arc<ValueCache<BlobId, Blob>>,
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
chain_id: ChainId,
execution_state_receiver: futures::channel::mpsc::UnboundedReceiver<ExecutionRequest>,
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
Expand All @@ -89,6 +92,7 @@ where
runtime_request_sender,
recent_hashed_certificate_values: certificate_value_cache,
recent_blobs: blob_cache,
tracked_chains,
knows_chain_is_active: false,
})
}
Expand Down Expand Up @@ -369,11 +373,41 @@ where
self.recent_blobs.insert(blob).await
}

/// Adds any newly created chains to the set of `tracked_chains`.
///
/// Scans the outgoing messages of `block` for `OpenChain` system messages and registers
/// the child chain derived from each such message. Does nothing when this worker has no
/// `tracked_chains` set configured.
fn track_newly_created_chains(&self, block: &ExecutedBlock) {
    let Some(tracked_chains) = self.tracked_chains.as_ref() else {
        return;
    };

    // The position in the flattened message list is used as the message index.
    // NOTE(review): `as u32` silently truncates positions above `u32::MAX` — presumably a
    // block never carries that many messages; confirm.
    let created_chains = block
        .messages()
        .iter()
        .flatten()
        .enumerate()
        .filter(|(_, outgoing)| {
            matches!(
                outgoing.message,
                Message::System(SystemMessage::OpenChain(_))
            )
        })
        .map(|(position, _)| ChainId::child(block.message_id(position as u32)));

    // The iterator is lazy, so the scan above runs while the write lock is held.
    let mut tracked = tracked_chains
        .write()
        .expect("Panics should not happen while holding a lock to `tracked_chains`");
    tracked.extend(created_chains);
}

/// Loads pending cross-chain requests.
async fn create_network_actions(&self) -> Result<NetworkActions, WorkerError> {
let mut heights_by_recipient: BTreeMap<_, BTreeMap<_, _>> = Default::default();
let pairs = self.chain.outboxes.try_load_all_entries().await?;
for (target, outbox) in pairs {
let mut targets = self.chain.outboxes.indices().await?;
if let Some(tracked_chains) = self.tracked_chains.as_ref() {
let tracked_chains = tracked_chains
.read()
.expect("Panics should not happen while holding a lock to `tracked_chains`");
targets.retain(|target| tracked_chains.contains(&target.recipient));
}
let outboxes = self.chain.outboxes.try_load_entries(&targets).await?;
for (target, outbox) in targets.into_iter().zip(outboxes) {
let outbox = outbox.expect("Only existing outboxes should be referenced by `indices`");
let heights = outbox.queue.elements().await?;
heights_by_recipient
.entry(target.recipient)
Expand Down
43 changes: 39 additions & 4 deletions linera-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
convert::Infallible,
iter,
ops::{Deref, DerefMut},
sync::Arc,
sync::{Arc, RwLock},
};

use dashmap::{
Expand Down Expand Up @@ -91,6 +91,9 @@ where
message_policy: MessagePolicy,
/// Whether to block on cross-chain message delivery.
cross_chain_message_delivery: CrossChainMessageDelivery,
/// Chains that should be tracked by the client.
// TODO(#2412): Merge with the set of chains for which the client receives notifications from validators.
Copy link
Contributor

@ma2bd ma2bd Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok that answers my question. We actually don't want that for the faucet, so perhaps we need two notions of "tracked" (replication allowed <> subscription)

tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
/// References to clients waiting for chain notifications.
notifier: Arc<Notifier<Notification>>,
/// A copy of the storage client so that we don't have to lock the local node client
Expand All @@ -111,10 +114,16 @@ where
storage: S,
max_pending_messages: usize,
cross_chain_message_delivery: CrossChainMessageDelivery,
tracked_chains: impl IntoIterator<Item = ChainId>,
) -> Self {
let state = WorkerState::new_for_client("Client node".to_string(), storage.clone())
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let tracked_chains = Arc::new(RwLock::new(tracked_chains.into_iter().collect()));
let state = WorkerState::new_for_client(
"Client node".to_string(),
storage.clone(),
tracked_chains.clone(),
)
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let local_node = LocalNodeClient::new(state);

Self {
Expand All @@ -124,6 +133,7 @@ where
max_pending_messages,
message_policy: MessagePolicy::new(BlanketMessagePolicy::Accept, None),
cross_chain_message_delivery,
tracked_chains,
notifier: Arc::new(Notifier::default()),
storage,
}
Expand All @@ -141,6 +151,15 @@ where
&self.local_node
}

/// Adds a chain to the set of chains tracked by the local node.
///
/// # Panics
///
/// Panics if the lock guarding `tracked_chains` has been poisoned.
#[tracing::instrument(level = "trace", skip(self))]
pub fn track_chain(&self, chain_id: ChainId) {
    let mut tracked = self
        .tracked_chains
        .write()
        .expect("Panics should not happen while holding a lock to `tracked_chains`");
    tracked.insert(chain_id);
}

#[tracing::instrument(level = "trace", skip_all, fields(chain_id, next_block_height))]
/// Creates a new `ChainClient`.
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -2501,6 +2520,12 @@ where
executed_block.message_id_for_operation(0, OPEN_CHAIN_MESSAGE_INDEX)
})
.ok_or_else(|| ChainClientError::InternalError("Failed to create new chain"))?;
// Add the new chain to the list of tracked chains
self.client.track_chain(ChainId::child(message_id));
self.client
.local_node
.retry_pending_cross_chain_requests(self.chain_id)
.await?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not even sure we always want that. E.g. the faucet probably shouldn't track all its children. Maybe it's better to default to not doing that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of what we discussed offline: for most users it makes sense to track child chains they create. For the Faucet it's not great, but not tracking them would increase the size of its outboxes, which could be less scalable.

Copy link
Contributor

@ma2bd ma2bd Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be sure: By "tracking" here, do you mean subscribing to notifications or just allowing a local chain state to exist?

(found the answer I think)

return Ok(ClientOutcome::Committed((message_id, certificate)));
}
}
Expand Down Expand Up @@ -2791,6 +2816,16 @@ where
.await
}

#[tracing::instrument(level = "trace")]
/// Handles any cross-chain requests for any pending outgoing messages.
pub async fn retry_pending_outgoing_messages(&self) -> Result<(), ChainClientError> {
self.client
.local_node
.retry_pending_cross_chain_requests(self.chain_id)
.await?;
Ok(())
}

#[tracing::instrument(level = "trace", skip(from, limit))]
pub async fn read_hashed_certificate_values_downward(
&self,
Expand Down
21 changes: 20 additions & 1 deletion linera-core/src/local_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use std::{
borrow::Cow,
collections::{HashMap, HashSet},
collections::{HashMap, HashSet, VecDeque},
sync::Arc,
};

Expand Down Expand Up @@ -570,4 +570,23 @@ where
}
}
}

/// Handles any pending local cross-chain requests.
///
/// Issues a chain info query for `sender_chain` against the local worker state and then
/// handles the cross-chain requests listed in the resulting actions, one at a time and in
/// order. Requests produced while handling a request are appended to the queue, so
/// processing continues until no request is left.
///
/// # Errors
///
/// Returns the first error raised by the chain info query or by handling a request.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn retry_pending_cross_chain_requests(
    &self,
    sender_chain: ChainId,
) -> Result<(), LocalNodeError> {
    let query = ChainInfoQuery::new(sender_chain);
    let (_chain_info, initial_actions) =
        self.node.state.handle_chain_info_query(query).await?;

    let mut queue: VecDeque<_> = initial_actions.cross_chain_requests.into_iter().collect();
    while let Some(next_request) = queue.pop_front() {
        let follow_up = self
            .node
            .state
            .handle_cross_chain_request(next_request)
            .await?;
        // Follow-up requests go to the back of the queue.
        queue.extend(follow_up.cross_chain_requests);
    }
    Ok(())
}
}
1 change: 1 addition & 0 deletions linera-core/src/unit_tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ where
storage,
10,
CrossChainMessageDelivery::NonBlocking,
[chain_id],
));
Ok(builder.create_chain_client(
chain_id,
Expand Down
35 changes: 31 additions & 4 deletions linera-core/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

use std::{
borrow::Cow,
collections::{hash_map, BTreeMap, HashMap, VecDeque},
collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque},
num::NonZeroUsize,
sync::{Arc, LazyLock, Mutex},
sync::{Arc, LazyLock, Mutex, RwLock},
time::Duration,
};

Expand Down Expand Up @@ -235,6 +235,8 @@ where
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
/// Cached blobs by `BlobId`.
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
/// Chain IDs that should be tracked by a worker.
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
/// One-shot channels to notify callers when messages of a particular chain have been
/// delivered.
delivery_notifiers: Arc<Mutex<DeliveryNotifiers>>,
Expand Down Expand Up @@ -264,15 +266,30 @@ where
chain_worker_config: ChainWorkerConfig::default().with_key_pair(key_pair),
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
recent_blobs: Arc::new(ValueCache::default()),
tracked_chains: None,
delivery_notifiers: Arc::default(),
chain_worker_tasks: Arc::default(),
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
}
}

#[tracing::instrument(level = "trace", skip(nickname, storage))]
pub fn new_for_client(nickname: String, storage: StorageClient) -> Self {
Self::new(nickname, None, storage)
pub fn new_for_client(
nickname: String,
storage: StorageClient,
tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
) -> Self {
WorkerState {
nickname,
storage,
chain_worker_config: ChainWorkerConfig::default(),
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
recent_blobs: Arc::new(ValueCache::default()),
tracked_chains: Some(tracked_chains),
delivery_notifiers: Arc::default(),
chain_worker_tasks: Arc::default(),
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
}
}

#[tracing::instrument(level = "trace", skip(self, value))]
Expand All @@ -288,6 +305,15 @@ where
self
}

/// Configures the subset of chains that this worker is tracking.
pub fn with_tracked_chains(
mut self,
tracked_chains: impl IntoIterator<Item = ChainId>,
) -> Self {
self.tracked_chains = Some(Arc::new(RwLock::new(tracked_chains.into_iter().collect())));
self
}

/// Returns an instance with the specified grace period, in microseconds.
///
/// Blocks with a timestamp this far in the future will still be accepted, but the validator
Expand Down Expand Up @@ -665,6 +691,7 @@ where
self.storage.clone(),
self.recent_hashed_certificate_values.clone(),
self.recent_blobs.clone(),
self.tracked_chains.clone(),
chain_id,
)
.await?;
Expand Down
1 change: 1 addition & 0 deletions linera-service/src/linera/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,7 @@ impl Job {
ViewError: From<S::StoreError>,
{
let state = WorkerState::new("Local node".to_string(), None, storage)
.with_tracked_chains([message_id.chain_id, chain_id])
.with_allow_inactive_chains(true)
.with_allow_messages_from_deprecated_epochs(true);
let node_client = LocalNodeClient::new(state);
Expand Down
Loading