From d3a83670a19ac23f8709aebfafd6a5330caa6834 Mon Sep 17 00:00:00 2001 From: Trantorian1 <114066155+Trantorian1@users.noreply.github.com> Date: Mon, 16 Dec 2024 09:55:24 +0100 Subject: [PATCH] feat(services): reworked Madara services for better cancellation control (#405) --- CHANGELOG.md | 1 + Cargo.lock | 5 +- Cargo.toml | 3 + README.md | 187 ++- crates/client/block_production/src/lib.rs | 24 +- crates/client/db/src/bonsai_db.rs | 5 +- crates/client/db/src/lib.rs | 11 +- crates/client/devnet/src/lib.rs | 2 +- crates/client/eth/src/l1_gas_price.rs | 51 +- crates/client/eth/src/l1_messaging.rs | 44 +- crates/client/eth/src/state_update.rs | 112 +- crates/client/eth/src/sync.rs | 26 +- crates/client/gateway/client/src/methods.rs | 2 +- crates/client/gateway/server/src/handler.rs | 5 +- crates/client/gateway/server/src/router.rs | 7 +- crates/client/gateway/server/src/service.rs | 65 +- crates/client/rpc/src/providers/mod.rs | 85 + .../rpc/src/versions/admin/v0_1_0/api.rs | 76 +- .../versions/admin/v0_1_0/methods/services.rs | 91 +- .../versions/admin/v0_1_0/methods/write.rs | 2 +- .../methods/read/get_block_with_tx_hashes.rs | 1 - .../user/v0_7_1/methods/read/get_nonce.rs | 1 - .../versions/user/v0_7_1/methods/write/mod.rs | 6 +- crates/client/sync/src/fetch/fetchers.rs | 71 +- .../sync/src/fetch/fetchers_real_fgw_test.rs | 18 +- crates/client/sync/src/fetch/mod.rs | 100 +- crates/client/sync/src/l2.rs | 58 +- crates/client/sync/src/lib.rs | 48 +- .../sync/src/tests/utils/read_resource.rs | 2 +- crates/client/telemetry/src/lib.rs | 154 +- crates/node/src/cli/l1.rs | 2 +- crates/node/src/cli/{sync.rs => l2.rs} | 21 +- crates/node/src/cli/mod.rs | 12 +- crates/node/src/main.rs | 288 ++-- crates/node/src/service/block_production.rs | 157 +- crates/node/src/service/gateway.rs | 59 +- crates/node/src/service/l1.rs | 52 +- crates/node/src/service/{sync.rs => l2.rs} | 44 +- crates/node/src/service/mod.rs | 4 +- crates/node/src/service/rpc/mod.rs | 185 ++- crates/node/src/service/rpc/server.rs | 43 +- crates/primitives/transactions/src/lib.rs | 1 - crates/primitives/utils/Cargo.toml | 3 + crates/primitives/utils/src/lib.rs | 39 - crates/primitives/utils/src/service.rs | 1364 ++++++++++++++--- crates/proc-macros/Cargo.toml | 15 +- crates/proc-macros/src/lib.rs | 18 +- 47 files changed, 2366 insertions(+), 1204 deletions(-) rename crates/node/src/cli/{sync.rs => l2.rs} (91%) rename crates/node/src/service/{sync.rs => l2.rs} (65%) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae4714d6a..1d830c8b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Next release +- feat(services): reworked Madara services for better cancellation control - feat: fetch eth/strk price and sync strk gas price - feat(block_production): continue pending block on restart - feat(mempool): mempool transaction saving on db diff --git a/Cargo.lock b/Cargo.lock index 1cc47c31f..66a6c5f5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -5352,6 +5352,7 @@ name = "m-proc-macros" version = "0.7.0" dependencies = [ "indoc 2.0.5", + "jsonrpsee", "proc-macro2", "quote", "syn 2.0.89", @@ -6177,12 +6178,14 @@ dependencies = [ "async-trait", "crypto-bigint", "futures", + "num-traits 0.2.19", "opentelemetry", "opentelemetry-appender-tracing", "opentelemetry-otlp", "opentelemetry-semantic-conventions", "opentelemetry-stdout", "opentelemetry_sdk", + "paste", "rand", "rayon", "rstest 0.18.2", diff --git a/Cargo.toml b/Cargo.toml index e46cfc2fc..12273e3be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -194,6 +194,9 @@ num-bigint = "0.4" primitive-types = "0.12" rand = "0.8" indoc = "2" +proc-macro2 = "1.0.86" +quote = "1.0.26" +syn = { version = "2.0.39", features = ["full"] } reqwest = { version = "0.12", features = ["blocking", "json"] } rstest = "0.18" serde = { version = "1.0", default-features = false, features = ["std"] } diff --git a/README.md b/README.md index 9dd0f44ae..c3b71ad31 100644 --- a/README.md +++ b/README.md @@ -29,15 +29,17 @@ Madara is a powerful Starknet client written in Rust. - ⚙️ [Configuration](#%EF%B8%8F-configuration) - [Basic Command-Line Options](#basic-command-line-options) - [Environment variables](#environment-variables) -- 🌐 [Interactions](#-interactions) + 🌐 [Interactions](#-interactions) - [Supported JSON-RPC Methods](#supported-json-rpc-methods) - [Madara-specific JSON-RPC Methods](#madara-specific-json-rpc-methods) - [Example of Calling a JSON-RPC Method](#example-of-calling-a-json-rpc-method) +- 📚 [Database Migration](#-database-migration) + - [Warp Update](#warp-update) + - [Running without `--warp-update-sender`](#running-without---warp-update-sender) - ✅ [Supported Features](#-supported-features) - [Starknet Compliant](#starknet-compliant) - [Feeder-Gateway State Synchronization](#feeder-gateway-state-synchronization) - [State Commitment Computation](#state-commitment-computation) - - [Database Migration](#database-migration) - 💬 [Get in touch](#-get-in-touch) - [Contributing](#contributing) - [Partnerships](#partnerships) @@ -113,7 +115,7 @@ cargo run --release -- \ --name Madara \ --sequencer \ --base-path /var/lib/madara \ - --preset test \ + --preset sepolia \ --l1-endpoint ${ETHEREUM_API_URL} ``` @@ -126,7 +128,7 @@ cargo run --release -- \ --name Madara \ --devnet \ --base-path /var/lib/madara \ - --preset test + --preset sepolia ``` > [!NOTE] @@ -426,16 +428,11 @@ are exposed on a separate port **9943** unless specified otherwise with <details> <summary>Status Methods</summary> -| Method | About | -| -------------------- | ---------------------------------------------------- | -| `madara_ping` | Return the unix time at which this method was called | -| `madara_shutdown` | Gracefully stops the running node | -| `madara_rpcDisable` | Disables user-facing rpc services | -| `madara_rpcEnable` | Enables user-facing rpc services | -| `madara_rpcRestart` | Restarts user-facing rpc services | -| `madara_syncDisable` | Disables l1 and l2 sync services | -| `madara_syncEnable` | Enables l1 and l2 sync services | -| `madara_syncRestart` | Restarts l1 and l2 sync services | +| Method | About | +| ----------------- | ---------------------------------------------------- | +| `madara_ping` | Return the unix time at which this method was called | +| `madara_shutdown` | Gracefully stops the running node | +| `madara_service` | Sets the status of one or more services | </details> @@ -544,6 +541,118 @@ into the subscription stream: Where `you-subscription-id` corresponds to the value of the `subscription` field which is returned with each websocket response. +## 📚 Database Migration + +[⬅️ back to top](#-madara-starknet-client) + +When migration to a newer version of Madara you might need to update your +database. Instead of re-synchronizing the entirety of your chain's state from +genesis, you can use Madara's **warp update** feature. This is essentially a +form of trusted sync with better performances as it is run from a local source. + +### Warp Update + +Warp update requires a working database source for the migration. If you do not +already have one, you can use the following command to generate a sample +database: + +```bash +cargo run --release -- \ + --name madara \ + --network mainnet \ + --full \ + --l1-sync-disabled `# We disable sync, for testing purposes` \ + --n-blocks-to-sync 1000 `# Only synchronize the first 1000 blocks` \ + --stop-on-sync `# ...and shutdown the node once this is done` +``` + +To begin the database migration, you will need to start your node with +[admin methods](#madara-specific-json-rpc-methods) and +[feeder gateway](#feeder-gateway-state-synchronization) enabled. This will be +the _source_ of the migration. You can do this with the `--warp-update-sender` +[preset](#4-presets): + +```bash +cargo run --release -- \ + --name Sender \ + --full `# This also works with other types of nodes` \ + --network mainnet \ + --warp-update-sender \ + --l1-sync-disabled `# We disable sync, for testing purposes` \ + --l2-sync-disabled +``` + +> [!TIP] +> Here, we have disabled sync for testing purposes, so the migration only +> synchronizes the blocks that were already present in the source node's +> database. In a production usecase, you most likely want the source node to +> keep synchronizing with an `--l1-endpoint`, that way when the migration is +> complete the receiver is fully up-to-date with any state that might have been +> produced by the chain _during the migration_. + +You will then need to start a second node to synchronize the state of your +database: + +```bash +cargo run --release -- \ + --name Receiver \ + --base-path /tmp/madara_new `# Where you want the new database to be stored` \ + --full \ + --network mainnet \ + --l1-sync-disabled `# We disable sync, for testing purposes` \ + --warp-update-receiver \ + --warp-update-shutdown-receiver `# Shuts down the receiver once the migration has completed` +``` + +This will start generating a new up-to-date database under `/tmp/madara_new`. +Once this process is over, the receiver node will automatically shutdown. + +> [!TIP] +> There also exists a `--warp-update--shutdown-sender` option which allows the +> receiver to take the place of the sender in certain limited circumstances. + +### Running without `--warp-update-sender` + +Up until now we have had to start a node with `--warp-update-sender` to begin +a migration, but this is only a [preset](#4-presets). In a production +environment, you can start your node with the following arguments and achieve +the same results: + +```bash +cargo run --release -- \ + --name Sender \ + --full `# This also works with other types of nodes` \ + --network mainnet \ + --feeder-gateway-enable `# The source of the migration` \ + --gateway-port 8080 `# Default port, change as required` \ + --rpc-admin `# Used to shutdown the sender after the migration` \ + --rpc-admin-port 9943 `# Default port, change as required` \ + --l1-sync-disabled `# We disable sync, for testing purposes` \ + --l2-sync-disabled +``` + +`--warp-update-receiver` doesn't override any cli arguments but is still needed +on the receiver end to start the migration. Here is an example of using it with +custom ports: + +> [!IMPORTANT] +> If you have already run a node with `--warp-update-receiver` following the +> examples above, remember to delete its database with `rm -rf /tmp/madara_new`. + +```bash +cargo run --release -- \ + --name Receiver \ + --base-path /tmp/madara_new `# Where you want the new database to be stored` \ + --full \ + --network mainnet \ + --l1-sync-disabled `# We disable sync, for testing purposes` \ + --warp-update-port-rpc 9943 `# Same as set with --rpc-admin-port on the sender` \ + --warp-update-port-fgw 8080 `# Same as set with --gateway-port on the sender` \ + --feeder-gateway-enable \ + --warp-update-receiver \ + --warp-update-shutdown-receiver `# Shuts down the receiver once the migration has completed` +``` + ## ✅ Supported Features [⬅️ back to top](#-madara-starknet-client) @@ -569,59 +678,11 @@ a regular sync. ### State Commitment Computation -Madara supports merkelized state verification through its own implementation of +Madara supports merkelized state commitments through its own implementation of Besu Bonsai Merkle Tries. See the [bonsai lib](https://github.com/madara-alliance/bonsai-trie). You can read more about Starknet Block structure and how it affects state commitment [here](https://docs.starknet.io/architecture-and-concepts/network-architecture/block-structure/). -### Database Migration - -When migration to a newer version of Madara you might need to update your -database. Instead of re-synchronizing the entirety of your chain's state from -genesis, you can use Madara's **warp update** feature. - -> [!NOTE] -> Warp update requires an already synchronized _local_ node with a working -> database. - -To begin the database migration, you will need to start an existing node with -[admin methods](#madara-specific-json-rpc-methods) and -[feeder gateway](#feeder-gateway-state-synchronization) enabled. This will be -the _source_ of the migration. You can do this with the `--warp-update-sender` -[preset](#4.-presets): - -```bash -cargo run --releasae -- \ - --name Sender \ - --full \ # This also works with other types of nodes - --network mainnet \ - --warp-update-sender -``` - -You will then need to start a second node to synchronize the state of your -database: - -```bash -cargo run --releasae -- \ - --name Receiver \ - --base-path /tmp/madara_new \ # Where you want the new database to be stored - --full \ - --network mainnet \ - --l1-endpoint https://*** \ - --warp-update-receiver -``` - -This will start generating a new up-to-date database under `/tmp/madara_new`. -Once this process is over, the warp update sender node will automatically -shutdown while the warp update receiver will take its place. - -> [!WARNING] -> As of now, the warp update receiver has its rpc disabled, even after the -> migration process has completed. This will be fixed in the future, so that -> services that would otherwise conflict with the sender node will automatically -> start after the migration has finished, allowing for migrations with 0 -> downtime. - ## 💬 Get in touch [⬅️ back to top](#-madara-starknet-client) diff --git a/crates/client/block_production/src/lib.rs b/crates/client/block_production/src/lib.rs index e7d310b83..623d9117d 100644 --- a/crates/client/block_production/src/lib.rs +++ b/crates/client/block_production/src/lib.rs @@ -35,7 +35,6 @@ use mp_convert::ToFelt; use mp_receipt::from_blockifier_execution_info; use mp_state_update::{ContractStorageDiffItem, StateDiff, StorageEntry}; use mp_transactions::TransactionWithHash; -use mp_utils::graceful_shutdown; use mp_utils::service::ServiceContext; use opentelemetry::KeyValue; use starknet_api::core::ClassHash; @@ -94,7 +93,7 @@ pub struct BlockProductionTask<Mempool: MempoolProvider> { pub(crate) executor: TransactionExecutor<BlockifierStateAdapter>, l1_data_provider: Arc<dyn L1DataProvider>, current_pending_tick: usize, - metrics: BlockProductionMetrics, + metrics: Arc<BlockProductionMetrics>, } impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { @@ -108,7 +107,7 @@ impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { backend: Arc<MadaraBackend>, importer: Arc<BlockImporter>, mempool: Arc<Mempool>, - metrics: BlockProductionMetrics, + metrics: Arc<BlockProductionMetrics>, l1_data_provider: Arc<dyn L1DataProvider>, ) -> Result<Self, Error> { let (pending_block, state_diff, pcs) = match backend.get_block(&DbBlockId::Pending)? { @@ -461,7 +460,7 @@ impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { } #[tracing::instrument(skip(self, ctx), fields(module = "BlockProductionTask"))] - pub async fn block_production_task(&mut self, ctx: ServiceContext) -> Result<(), anyhow::Error> { + pub async fn block_production_task(mut self, mut ctx: ServiceContext) -> Result<(), anyhow::Error> { let start = tokio::time::Instant::now(); let mut interval_block_time = tokio::time::interval_at(start, self.backend.chain_config().block_time); @@ -480,10 +479,13 @@ impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { instant = interval_block_time.tick() => { if let Err(err) = self.on_block_time().await { tracing::error!("Block production task has errored: {err:#}"); - // Clear pending block. The reason we do this is because if the error happened because the closed - // block is invalid or has not been saved properly, we want to avoid redoing the same error in the next - // block. So we drop all the transactions in the pending block just in case. - // If the problem happened after the block was closed and saved to the db, this will do nothing. + // Clear pending block. The reason we do this is because + // if the error happened because the closed block is + // invalid or has not been saved properly, we want to + // avoid redoing the same error in the next block. So we + // drop all the transactions in the pending block just + // in case. If the problem happened after the block was + // closed and saved to the db, this will do nothing. if let Err(err) = self.backend.clear_pending_block() { tracing::error!("Error while clearing the pending block in recovery of block production error: {err:#}"); } @@ -495,8 +497,8 @@ impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { let n_pending_ticks_per_block = self.backend.chain_config().n_pending_ticks_per_block(); if self.current_pending_tick == 0 || self.current_pending_tick >= n_pending_ticks_per_block { - // first tick is ignored. - // out of range ticks are also ignored. + // First tick is ignored. Out of range ticks are also + // ignored. self.current_pending_tick += 1; continue } @@ -506,7 +508,7 @@ impl<Mempool: MempoolProvider> BlockProductionTask<Mempool> { } self.current_pending_tick += 1; }, - _ = graceful_shutdown(&ctx) => break, + _ = ctx.cancelled() => break, } } diff --git a/crates/client/db/src/bonsai_db.rs b/crates/client/db/src/bonsai_db.rs index f3ee9c53a..3b07f2bfd 100644 --- a/crates/client/db/src/bonsai_db.rs +++ b/crates/client/db/src/bonsai_db.rs @@ -250,7 +250,10 @@ impl BonsaiDatabase for BonsaiTransaction { } impl BonsaiPersistentDatabase<BasicId> for BonsaiDb { - type Transaction<'a> = BonsaiTransaction where Self: 'a; + type Transaction<'a> + = BonsaiTransaction + where + Self: 'a; type DatabaseError = DbError; /// this is called upstream, but we ignore it for now because we create the snapshot in [`crate::MadaraBackend::store_block`] diff --git a/crates/client/db/src/lib.rs b/crates/client/db/src/lib.rs index 61a92269f..6811cf3f6 100644 --- a/crates/client/db/src/lib.rs +++ b/crates/client/db/src/lib.rs @@ -6,7 +6,7 @@ use bonsai_db::{BonsaiDb, DatabaseKeyMapping}; use bonsai_trie::{BonsaiStorage, BonsaiStorageConfig}; use db_metrics::DbMetrics; use mp_chain_config::ChainConfig; -use mp_utils::service::{MadaraService, Service}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId}; use rocksdb::backup::{BackupEngine, BackupEngineOptions}; use rocksdb::{ BoundColumnFamily, ColumnFamilyDescriptor, DBWithThreadMode, Env, FlushOptions, MultiThreaded, WriteOptions, @@ -346,9 +346,12 @@ impl DatabaseService { } } -impl Service for DatabaseService { - fn id(&self) -> MadaraService { - MadaraService::Database +impl Service for DatabaseService {} + +impl ServiceId for DatabaseService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::Database.svc_id() } } diff --git a/crates/client/devnet/src/lib.rs b/crates/client/devnet/src/lib.rs index f8e0a7e76..7f0dae113 100644 --- a/crates/client/devnet/src/lib.rs +++ b/crates/client/devnet/src/lib.rs @@ -342,7 +342,7 @@ mod tests { Arc::clone(&backend), Arc::clone(&importer), Arc::clone(&mempool), - metrics, + Arc::new(metrics), Arc::clone(&l1_data_provider), ) .unwrap(); diff --git a/crates/client/eth/src/l1_gas_price.rs b/crates/client/eth/src/l1_gas_price.rs index ad5986eec..4d02f4a37 100644 --- a/crates/client/eth/src/l1_gas_price.rs +++ b/crates/client/eth/src/l1_gas_price.rs @@ -4,23 +4,27 @@ use alloy::providers::Provider; use anyhow::Context; use bigdecimal::BigDecimal; use mc_mempool::{GasPriceProvider, L1DataProvider}; -use std::time::{Duration, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{Duration, UNIX_EPOCH}, +}; -use mp_utils::{service::ServiceContext, wait_or_graceful_shutdown}; +use mp_utils::service::ServiceContext; use std::time::SystemTime; pub async fn gas_price_worker_once( eth_client: &EthereumClient, - l1_gas_provider: GasPriceProvider, + l1_gas_provider: &GasPriceProvider, gas_price_poll_ms: Duration, ) -> anyhow::Result<()> { - match update_gas_price(eth_client, l1_gas_provider.clone()).await { + match update_gas_price(eth_client, l1_gas_provider).await { Ok(_) => tracing::trace!("Updated gas prices"), Err(e) => tracing::error!("Failed to update gas prices: {:?}", e), } let last_update_timestamp = l1_gas_provider.get_gas_prices_last_update(); let duration_since_last_update = SystemTime::now().duration_since(last_update_timestamp)?; + let last_update_timestemp = last_update_timestamp.duration_since(UNIX_EPOCH).expect("SystemTime before UNIX EPOCH!").as_micros(); if duration_since_last_update > 10 * gas_price_poll_ms { @@ -31,24 +35,26 @@ pub async fn gas_price_worker_once( ); } - Ok(()) + anyhow::Ok(()) } pub async fn gas_price_worker( - eth_client: &EthereumClient, + eth_client: Arc<EthereumClient>, l1_gas_provider: GasPriceProvider, gas_price_poll_ms: Duration, - ctx: ServiceContext, + mut ctx: ServiceContext, ) -> anyhow::Result<()> { l1_gas_provider.update_last_update_timestamp(); let mut interval = tokio::time::interval(gas_price_poll_ms); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - while wait_or_graceful_shutdown(interval.tick(), &ctx).await.is_some() { - gas_price_worker_once(eth_client, l1_gas_provider.clone(), gas_price_poll_ms).await?; + + while ctx.run_until_cancelled(interval.tick()).await.is_some() { + gas_price_worker_once(ð_client, &l1_gas_provider, gas_price_poll_ms).await?; } - Ok(()) + + anyhow::Ok(()) } -async fn update_gas_price(eth_client: &EthereumClient, l1_gas_provider: GasPriceProvider) -> anyhow::Result<()> { +async fn update_gas_price(eth_client: &EthereumClient, l1_gas_provider: &GasPriceProvider) -> anyhow::Result<()> { let block_number = eth_client.get_latest_block_number().await?; let fee_history = eth_client.provider.get_fee_history(300, BlockNumberOrTag::Number(block_number), &[]).await?; @@ -100,7 +106,10 @@ async fn update_gas_price(eth_client: &EthereumClient, l1_gas_provider: GasPrice Ok(()) } -async fn update_l1_block_metrics(eth_client: &EthereumClient, l1_gas_provider: GasPriceProvider) -> anyhow::Result<()> { +async fn update_l1_block_metrics( + eth_client: &EthereumClient, + l1_gas_provider: &GasPriceProvider, +) -> anyhow::Result<()> { // Get the latest block number let latest_block_number = eth_client.get_latest_block_number().await?; @@ -144,7 +153,7 @@ mod eth_client_gas_price_worker_test { let l1_gas_provider = l1_gas_provider.clone(); async move { gas_price_worker( - ð_client, + Arc::new(eth_client), l1_gas_provider, Duration::from_millis(200), ServiceContext::new_for_testing(), @@ -183,7 +192,7 @@ mod eth_client_gas_price_worker_test { let l1_gas_provider = GasPriceProvider::new(); // Run the worker for a short time - let worker_handle = gas_price_worker_once(ð_client, l1_gas_provider.clone(), Duration::from_millis(200)); + let worker_handle = gas_price_worker_once(ð_client, &l1_gas_provider, Duration::from_millis(200)); // Wait for the worker to complete worker_handle.await.expect("issue with the gas worker"); @@ -204,7 +213,7 @@ mod eth_client_gas_price_worker_test { l1_gas_provider.set_gas_price_sync_enabled(false); // Run the worker for a short time - let worker_handle = gas_price_worker_once(ð_client, l1_gas_provider.clone(), Duration::from_millis(200)); + let worker_handle = gas_price_worker_once(ð_client, &l1_gas_provider, Duration::from_millis(200)); // Wait for the worker to complete worker_handle.await.expect("issue with the gas worker"); @@ -225,7 +234,7 @@ mod eth_client_gas_price_worker_test { l1_gas_provider.set_data_gas_price_sync_enabled(false); // Run the worker for a short time - let worker_handle = gas_price_worker_once(ð_client, l1_gas_provider.clone(), Duration::from_millis(200)); + let worker_handle = gas_price_worker_once(ð_client, &l1_gas_provider, Duration::from_millis(200)); // Wait for the worker to complete worker_handle.await.expect("issue with the gas worker"); @@ -261,8 +270,10 @@ mod eth_client_gas_price_worker_test { }); mock_server.mock(|when, then| { - when.method("POST").path("/").json_body_obj(&serde_json::json!({"id":0,"jsonrpc":"2.0","method":"eth_blockNumber"})); - then.status(200).json_body_obj(&serde_json::json!({"jsonrpc":"2.0","id":1,"result":"0x0137368e"} )); + when.method("POST") + .path("/") + .json_body_obj(&serde_json::json!({"id":0,"jsonrpc":"2.0","method":"eth_blockNumber"})); + then.status(200).json_body_obj(&serde_json::json!({"jsonrpc":"2.0","id":1,"result":"0x0137368e"})); }); let l1_gas_provider = GasPriceProvider::new(); @@ -274,7 +285,7 @@ mod eth_client_gas_price_worker_test { let result = timeout( timeout_duration, gas_price_worker( - ð_client, + Arc::new(eth_client), l1_gas_provider.clone(), Duration::from_millis(200), ServiceContext::new_for_testing(), @@ -310,7 +321,7 @@ mod eth_client_gas_price_worker_test { l1_gas_provider.update_last_update_timestamp(); // Update gas prices - update_gas_price(ð_client, l1_gas_provider.clone()).await.expect("Failed to update gas prices"); + update_gas_price(ð_client, &l1_gas_provider).await.expect("Failed to update gas prices"); // Access the updated gas prices let updated_prices = l1_gas_provider.get_gas_prices(); diff --git a/crates/client/eth/src/l1_messaging.rs b/crates/client/eth/src/l1_messaging.rs index 78e92c8db..ee99286f1 100644 --- a/crates/client/eth/src/l1_messaging.rs +++ b/crates/client/eth/src/l1_messaging.rs @@ -8,7 +8,6 @@ use anyhow::Context; use futures::StreamExt; use mc_db::{l1_db::LastSyncedEventBlock, MadaraBackend}; use mc_mempool::{Mempool, MempoolProvider}; -use mp_utils::channel_wait_or_graceful_shutdown; use mp_utils::service::ServiceContext; use starknet_api::core::{ChainId, ContractAddress, EntryPointSelector, Nonce}; use starknet_api::transaction::{Calldata, L1HandlerTransaction, TransactionVersion}; @@ -37,11 +36,11 @@ impl EthereumClient { } pub async fn sync( - backend: &MadaraBackend, - client: &EthereumClient, - chain_id: &ChainId, + backend: Arc<MadaraBackend>, + client: Arc<EthereumClient>, + chain_id: ChainId, mempool: Arc<Mempool>, - ctx: ServiceContext, + mut ctx: ServiceContext, ) -> anyhow::Result<()> { tracing::info!("⟠ Starting L1 Messages Syncing..."); @@ -66,7 +65,8 @@ pub async fn sync( "Failed to watch event filter - Ensure you are using an L1 RPC endpoint that points to an archive node", )? .into_stream(); - while let Some(event_result) = channel_wait_or_graceful_shutdown(event_stream.next(), &ctx).await { + + while let Some(Some(event_result)) = ctx.run_until_cancelled(event_stream.next()).await { if let Ok((event, meta)) = event_result { tracing::info!( "⟠ Processing L1 Message from block: {:?}, transaction_hash: {:?}, log_index: {:?}, fromAddress: {:?}", @@ -97,7 +97,7 @@ pub async fn sync( continue; } - match process_l1_message(backend, &event, &meta.block_number, &meta.log_index, chain_id, mempool.clone()) + match process_l1_message(&backend, &event, &meta.block_number, &meta.log_index, &chain_id, mempool.clone()) .await { Ok(Some(tx_hash)) => { @@ -407,8 +407,14 @@ mod l1_messaging_tests { let worker_handle = { let db = Arc::clone(&db); tokio::spawn(async move { - sync(db.backend(), ð_client, &chain_config.chain_id, mempool, ServiceContext::new_for_testing()) - .await + sync( + Arc::clone(db.backend()), + Arc::new(eth_client), + chain_config.chain_id.clone(), + mempool, + ServiceContext::new_for_testing(), + ) + .await }) }; @@ -462,8 +468,14 @@ mod l1_messaging_tests { let worker_handle = { let db = Arc::clone(&db); tokio::spawn(async move { - sync(db.backend(), ð_client, &chain_config.chain_id, mempool, ServiceContext::new_for_testing()) - .await + sync( + Arc::clone(db.backend()), + Arc::new(eth_client), + chain_config.chain_id.clone(), + mempool, + ServiceContext::new_for_testing(), + ) + .await }) }; @@ -512,8 +524,14 @@ mod l1_messaging_tests { let worker_handle = { let db = Arc::clone(&db); tokio::spawn(async move { - sync(db.backend(), ð_client, &chain_config.chain_id, mempool, ServiceContext::new_for_testing()) - .await + sync( + Arc::clone(db.backend()), + Arc::new(eth_client), + chain_config.chain_id.clone(), + mempool, + ServiceContext::new_for_testing(), + ) + .await }) }; diff --git a/crates/client/eth/src/state_update.rs b/crates/client/eth/src/state_update.rs index fce55b777..c38bce421 100644 --- a/crates/client/eth/src/state_update.rs +++ b/crates/client/eth/src/state_update.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::client::{L1BlockMetrics, StarknetCoreContract}; use crate::{ client::EthereumClient, @@ -6,14 +8,13 @@ use crate::{ use anyhow::Context; use futures::StreamExt; use mc_db::MadaraBackend; -use mp_convert::ToFelt; -use mp_transactions::MAIN_CHAIN_ID; -use mp_utils::channel_wait_or_graceful_shutdown; use mp_utils::service::ServiceContext; use serde::Deserialize; -use starknet_api::core::ChainId; use starknet_types_core::felt::Felt; +const ERR_ARCHIVE: &str = + "Failed to watch event filter - Ensure you are using an L1 RPC endpoint that points to an archive node"; + #[derive(Debug, Clone, Deserialize, PartialEq)] pub struct L1StateUpdate { pub block_number: u64, @@ -30,85 +31,59 @@ pub async fn get_initial_state(client: &EthereumClient) -> anyhow::Result<L1Stat Ok(L1StateUpdate { global_root, block_number, block_hash }) } -/// Subscribes to the LogStateUpdate event from the Starknet core contract and store latest -/// verified state -pub async fn listen_and_update_state( - eth_client: &EthereumClient, - backend: &MadaraBackend, - block_metrics: &L1BlockMetrics, - chain_id: ChainId, - ctx: ServiceContext, -) -> anyhow::Result<()> { - let event_filter = eth_client.l1_core_contract.event_filter::<StarknetCoreContract::LogStateUpdate>(); - - let mut event_stream = event_filter - .watch() - .await - .context( - "Failed to watch event filter - Ensure you are using an L1 RPC endpoint that points to an archive node", - )? - .into_stream(); - - while let Some(event_result) = channel_wait_or_graceful_shutdown(event_stream.next(), &ctx).await { - let log = event_result.context("listening for events")?; - let format_event: L1StateUpdate = - convert_log_state_update(log.0.clone()).context("formatting event into an L1StateUpdate")?; - update_l1(backend, format_event, block_metrics, chain_id.clone())?; - } - - Ok(()) -} - pub fn update_l1( backend: &MadaraBackend, state_update: L1StateUpdate, block_metrics: &L1BlockMetrics, - chain_id: ChainId, ) -> anyhow::Result<()> { - // This is a provisory check to avoid updating the state with an L1StateUpdate that should not have been detected - // - // TODO: Remove this check when the L1StateUpdate is properly verified - if state_update.block_number > 500000u64 || chain_id.to_felt() == MAIN_CHAIN_ID { - tracing::info!( - "🔄 Updated L1 head #{} ({}) with state root ({})", - state_update.block_number, - trim_hash(&state_update.block_hash), - trim_hash(&state_update.global_root) - ); + tracing::info!( + "🔄 Updated L1 head #{} ({}) with state root ({})", + state_update.block_number, + trim_hash(&state_update.block_hash), + trim_hash(&state_update.global_root) + ); - block_metrics.l1_block_number.record(state_update.block_number, &[]); + block_metrics.l1_block_number.record(state_update.block_number, &[]); - backend - .write_last_confirmed_block(state_update.block_number) - .context("Setting l1 last confirmed block number")?; - tracing::debug!("update_l1: wrote last confirmed block number"); - } + backend.write_last_confirmed_block(state_update.block_number).context("Setting l1 last confirmed block number")?; + tracing::debug!("update_l1: wrote last confirmed block number"); Ok(()) } pub async fn state_update_worker( - backend: &MadaraBackend, - eth_client: &EthereumClient, - chain_id: ChainId, - ctx: ServiceContext, + backend: Arc<MadaraBackend>, + eth_client: Arc<EthereumClient>, + mut ctx: ServiceContext, ) -> anyhow::Result<()> { // Clear L1 confirmed block at startup backend.clear_last_confirmed_block().context("Clearing l1 last confirmed block number")?; tracing::debug!("update_l1: cleared confirmed block number"); tracing::info!("🚀 Subscribed to L1 state verification"); - // ideally here there would be one service which will update the l1 gas prices and another one for messages and one that's already present is state update - // Get and store the latest verified state - let initial_state = get_initial_state(eth_client).await.context("Getting initial ethereum state")?; - update_l1(backend, initial_state, ð_client.l1_block_metrics, chain_id.clone())?; + // This does not seem to play well with anvil + #[cfg(not(test))] + { + let initial_state = get_initial_state(ð_client).await.context("Getting initial ethereum state")?; + update_l1(&backend, initial_state, ð_client.l1_block_metrics)?; + } - // Listen to LogStateUpdate (0x77552641) update and send changes continusly - listen_and_update_state(eth_client, backend, ð_client.l1_block_metrics, chain_id, ctx) - .await - .context("Subscribing to the LogStateUpdate event")?; + // Listen to LogStateUpdate (0x77552641) update and send changes continuously + let event_filter = eth_client.l1_core_contract.event_filter::<StarknetCoreContract::LogStateUpdate>(); - Ok(()) + let mut event_stream = match ctx.run_until_cancelled(event_filter.watch()).await { + Some(res) => res.context(ERR_ARCHIVE)?.into_stream(), + None => return anyhow::Ok(()), + }; + + while let Some(Some(event_result)) = ctx.run_until_cancelled(event_stream.next()).await { + let log = event_result.context("listening for events")?; + let format_event: L1StateUpdate = + convert_log_state_update(log.0.clone()).context("formatting event into an L1StateUpdate")?; + update_l1(&backend, format_event, ð_client.l1_block_metrics)?; + } + + anyhow::Ok(()) } #[cfg(test)] @@ -194,14 +169,9 @@ mod eth_client_event_subscription_test { let listen_handle = { let db = Arc::clone(&db); tokio::spawn(async move { - listen_and_update_state( - ð_client, - db.backend(), - ð_client.l1_block_metrics, - chain_info.chain_id.clone(), - ServiceContext::new_for_testing(), - ) - .await + state_update_worker(Arc::clone(db.backend()), Arc::new(eth_client), ServiceContext::new_for_testing()) + .await + .unwrap() }) }; diff --git a/crates/client/eth/src/sync.rs b/crates/client/eth/src/sync.rs index a4794a7a7..2a607fcde 100644 --- a/crates/client/eth/src/sync.rs +++ b/crates/client/eth/src/sync.rs @@ -12,8 +12,8 @@ use mc_db::MadaraBackend; #[allow(clippy::too_many_arguments)] pub async fn l1_sync_worker( - backend: &MadaraBackend, - eth_client: &EthereumClient, + backend: Arc<MadaraBackend>, + eth_client: Arc<EthereumClient>, chain_id: ChainId, l1_gas_provider: GasPriceProvider, gas_price_sync_disabled: bool, @@ -21,16 +21,18 @@ pub async fn l1_sync_worker( mempool: Arc<Mempool>, ctx: ServiceContext, ) -> anyhow::Result<()> { - tokio::try_join!( - state_update_worker(backend, eth_client, chain_id.clone(), ctx.clone()), - async { - if !gas_price_sync_disabled { - gas_price_worker(eth_client, l1_gas_provider, gas_price_poll_ms, ctx.clone()).await?; - } - Ok(()) - }, - sync(backend, eth_client, &chain_id, mempool, ctx.clone()) - )?; + let mut join_set = tokio::task::JoinSet::new(); + + join_set.spawn(state_update_worker(Arc::clone(&backend), Arc::clone(ð_client), ctx.clone())); + join_set.spawn(sync(Arc::clone(&backend), Arc::clone(ð_client), chain_id, mempool, ctx.clone())); + + if !gas_price_sync_disabled { + join_set.spawn(gas_price_worker(Arc::clone(ð_client), l1_gas_provider, gas_price_poll_ms, ctx.clone())); + } + + while let Some(res) = join_set.join_next().await { + res??; + } Ok(()) } diff --git a/crates/client/gateway/client/src/methods.rs b/crates/client/gateway/client/src/methods.rs index 1c214c875..632475ee8 100644 --- a/crates/client/gateway/client/src/methods.rs +++ b/crates/client/gateway/client/src/methods.rs @@ -213,7 +213,7 @@ mod tests { } } - impl<'a> Drop for FileCleanupGuard<'a> { + impl Drop for FileCleanupGuard<'_> { fn drop(&mut self) { if self.is_active { let _ = remove_file(self.path); diff --git a/crates/client/gateway/server/src/handler.rs b/crates/client/gateway/server/src/handler.rs index fc9696ce8..c239c64c4 100644 --- a/crates/client/gateway/server/src/handler.rs +++ b/crates/client/gateway/server/src/handler.rs @@ -230,6 +230,7 @@ pub async fn handle_get_block_traces( req: Request<Incoming>, backend: Arc<MadaraBackend>, add_transaction_provider: Arc<dyn AddTransactionProvider>, + ctx: ServiceContext, ) -> Result<Response<String>, GatewayError> { let params = get_params_from_request(&req); let block_id = block_id_from_params(¶ms).or_internal_server_error("Retrieving block id")?; @@ -239,10 +240,8 @@ pub async fn handle_get_block_traces( traces: Vec<TraceBlockTransactionsResult<Felt>>, } - // TODO: we should probably use the actual service context here instead of - // creating a new one! let traces = v0_7_1_trace_block_transactions( - &Starknet::new(backend, add_transaction_provider, Default::default(), ServiceContext::new()), + &Starknet::new(backend, add_transaction_provider, Default::default(), ctx), block_id, ) .await?; diff --git a/crates/client/gateway/server/src/router.rs b/crates/client/gateway/server/src/router.rs index f0577901c..009863ee0 100644 --- a/crates/client/gateway/server/src/router.rs +++ b/crates/client/gateway/server/src/router.rs @@ -3,6 +3,7 @@ use std::{convert::Infallible, sync::Arc}; use hyper::{body::Incoming, Method, Request, Response}; use mc_db::MadaraBackend; use mc_rpc::providers::AddTransactionProvider; +use mp_utils::service::ServiceContext; use super::handler::{ handle_add_transaction, handle_get_block, handle_get_block_traces, handle_get_class_by_hash, @@ -16,6 +17,7 @@ pub(crate) async fn main_router( req: Request<Incoming>, backend: Arc<MadaraBackend>, add_transaction_provider: Arc<dyn AddTransactionProvider>, + ctx: ServiceContext, feeder_gateway_enable: bool, gateway_enable: bool, ) -> Result<Response<String>, Infallible> { @@ -23,7 +25,7 @@ pub(crate) async fn main_router( match (path.as_ref(), feeder_gateway_enable, gateway_enable) { ("health", _, _) => Ok(Response::new("OK".to_string())), (path, true, _) if path.starts_with("feeder_gateway/") => { - feeder_gateway_router(req, path, backend, add_transaction_provider).await + feeder_gateway_router(req, path, backend, add_transaction_provider, ctx).await } (path, _, true) if path.starts_with("gateway/") => gateway_router(req, path, add_transaction_provider).await, (path, false, _) if path.starts_with("feeder_gateway/") => Ok(service_unavailable_response("Feeder Gateway")), @@ -41,6 +43,7 @@ async fn feeder_gateway_router( path: &str, backend: Arc<MadaraBackend>, add_transaction_provider: Arc<dyn AddTransactionProvider>, + ctx: ServiceContext, ) -> Result<Response<String>, Infallible> { match (req.method(), path) { (&Method::GET, "feeder_gateway/get_block") => { @@ -53,7 +56,7 @@ async fn feeder_gateway_router( Ok(handle_get_state_update(req, backend).await.unwrap_or_else(Into::into)) } (&Method::GET, "feeder_gateway/get_block_traces") => { - Ok(handle_get_block_traces(req, backend, add_transaction_provider).await.unwrap_or_else(Into::into)) + Ok(handle_get_block_traces(req, backend, add_transaction_provider, ctx).await.unwrap_or_else(Into::into)) } (&Method::GET, "feeder_gateway/get_class_by_hash") => { Ok(handle_get_class_by_hash(req, backend).await.unwrap_or_else(Into::into)) diff --git a/crates/client/gateway/server/src/service.rs b/crates/client/gateway/server/src/service.rs index d199fd549..0401f3cfc 100644 --- a/crates/client/gateway/server/src/service.rs +++ b/crates/client/gateway/server/src/service.rs @@ -8,8 +8,8 @@ use hyper::{server::conn::http1, service::service_fn}; use hyper_util::rt::TokioIo; use mc_db::MadaraBackend; use mc_rpc::providers::AddTransactionProvider; -use mp_utils::{graceful_shutdown, service::ServiceContext}; -use tokio::{net::TcpListener, sync::Notify}; +use mp_utils::service::ServiceContext; +use tokio::net::TcpListener; use super::router::main_router; @@ -20,7 +20,7 @@ pub async fn start_server( gateway_enable: bool, gateway_external: bool, gateway_port: u16, - ctx: ServiceContext, + mut ctx: ServiceContext, ) -> anyhow::Result<()> { if !feeder_gateway_enable && !gateway_enable { return Ok(()); @@ -36,46 +36,33 @@ pub async fn start_server( tracing::info!("🌐 Gateway endpoint started at {}", addr); - let shutdown_notify = Arc::new(Notify::new()); + while let Some(res) = ctx.run_until_cancelled(listener.accept()).await { + // Handle new incoming connections + if let Ok((stream, _)) = res { + let io = TokioIo::new(stream); - { - let shutdown_notify = Arc::clone(&shutdown_notify); - tokio::spawn(async move { - graceful_shutdown(&ctx).await; - shutdown_notify.notify_waiters(); - }); - } - - loop { - tokio::select! { - // Handle new incoming connections - Ok((stream, _)) = listener.accept() => { - let io = TokioIo::new(stream); - - let db_backend = Arc::clone(&db_backend); - let add_transaction_provider = Arc::clone(&add_transaction_provider); + let db_backend = Arc::clone(&db_backend); + let add_transaction_provider = add_transaction_provider.clone(); + let ctx = ctx.clone(); - tokio::task::spawn(async move { - let service = service_fn(move |req| { - main_router( - req, - Arc::clone(&db_backend), - Arc::clone(&add_transaction_provider), - feeder_gateway_enable, - gateway_enable, - ) - }); - - if let Err(err) = http1::Builder::new().serve_connection(io, service).await { - tracing::error!("Error serving connection: {:?}", err); - } + tokio::task::spawn(async move { + let service = service_fn(move |req| { + main_router( + req, + Arc::clone(&db_backend), + add_transaction_provider.clone(), + ctx.clone(), + feeder_gateway_enable, + gateway_enable, + ) }); - }, - // Await the shutdown signal - _ = shutdown_notify.notified() => { - break Ok(()); - } + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + tracing::error!("Error serving connection: {:?}", err); + } + }); } } + + anyhow::Ok(()) } diff --git a/crates/client/rpc/src/providers/mod.rs b/crates/client/rpc/src/providers/mod.rs index a5f616871..47a35c62a 100644 --- a/crates/client/rpc/src/providers/mod.rs +++ b/crates/client/rpc/src/providers/mod.rs @@ -1,17 +1,22 @@ pub mod forward_to_provider; pub mod mempool; +use std::sync::Arc; + pub use forward_to_provider::*; pub use mempool::*; use jsonrpsee::core::{async_trait, RpcResult}; use mp_transactions::BroadcastedDeclareTransactionV0; +use mp_utils::service::{MadaraServiceId, ServiceContext}; use starknet_types_core::felt::Felt; use starknet_types_rpc::{ AddInvokeTransactionResult, BroadcastedDeclareTxn, BroadcastedDeployAccountTxn, BroadcastedInvokeTxn, ClassAndTxnHash, ContractAndTxnHash, }; +use crate::utils::OptionExt; + #[async_trait] pub trait AddTransactionProvider: Send + Sync { async fn add_declare_v0_transaction( @@ -33,3 +38,83 @@ pub trait AddTransactionProvider: Send + Sync { invoke_transaction: BroadcastedInvokeTxn<Felt>, ) -> RpcResult<AddInvokeTransactionResult<Felt>>; } + +/// A simple struct whose sole purpose is to toggle between a L2 sync and local +/// (sequencer) transaction provider depending on the state of the node as +/// specified by [ServiceContext]. +/// +/// - If we are relying on L2 sync, then all transactions are forwarded to the +/// sequencer. +/// +/// - If we are relying on local block production, then transactions are also +/// executed locally. +/// +/// This exists to accommodate warp updates, which require we toggle from L2 +/// sync transaction forwarding to local transaction execution if we are +/// launching the sync on a local sequencer. +#[derive(Clone)] +pub struct AddTransactionProviderGroup { + l2_sync: Arc<dyn AddTransactionProvider>, + mempool: Arc<dyn AddTransactionProvider>, + ctx: ServiceContext, +} + +impl AddTransactionProviderGroup { + pub const ERROR: &str = + "Failed to retrieve add transaction provider, meaning neither l2 sync nor block production are running"; + + pub fn new( + l2_sync: Arc<dyn AddTransactionProvider>, + mempool: Arc<dyn AddTransactionProvider>, + ctx: ServiceContext, + ) -> Self { + Self { l2_sync, mempool, ctx } + } + + fn provider(&self) -> Option<&Arc<dyn AddTransactionProvider>> { + if self.ctx.service_status(MadaraServiceId::L2Sync).is_on() { + Some(&self.l2_sync) + } else if self.ctx.service_status(MadaraServiceId::BlockProduction).is_on() { + Some(&self.mempool) + } else { + None + } + } +} + +#[async_trait] +impl AddTransactionProvider for AddTransactionProviderGroup { + async fn add_declare_v0_transaction( + &self, + declare_v0_transaction: BroadcastedDeclareTransactionV0, + ) -> RpcResult<ClassAndTxnHash<Felt>> { + self.provider() + .ok_or_internal_server_error(Self::ERROR)? + .add_declare_v0_transaction(declare_v0_transaction) + .await + } + + async fn add_declare_transaction( + &self, + declare_transaction: BroadcastedDeclareTxn<Felt>, + ) -> RpcResult<ClassAndTxnHash<Felt>> { + self.provider().ok_or_internal_server_error(Self::ERROR)?.add_declare_transaction(declare_transaction).await + } + + async fn add_deploy_account_transaction( + &self, + deploy_account_transaction: BroadcastedDeployAccountTxn<Felt>, + ) -> RpcResult<ContractAndTxnHash<Felt>> { + self.provider() + .ok_or_internal_server_error(Self::ERROR)? + .add_deploy_account_transaction(deploy_account_transaction) + .await + } + + async fn add_invoke_transaction( + &self, + invoke_transaction: BroadcastedInvokeTxn<Felt>, + ) -> RpcResult<AddInvokeTransactionResult<Felt>> { + self.provider().ok_or_internal_server_error(Self::ERROR)?.add_invoke_transaction(invoke_transaction).await + } +} diff --git a/crates/client/rpc/src/versions/admin/v0_1_0/api.rs b/crates/client/rpc/src/versions/admin/v0_1_0/api.rs index aa970311a..e2174d575 100644 --- a/crates/client/rpc/src/versions/admin/v0_1_0/api.rs +++ b/crates/client/rpc/src/versions/admin/v0_1_0/api.rs @@ -1,9 +1,19 @@ use jsonrpsee::core::RpcResult; use m_proc_macros::versioned_rpc; use mp_transactions::BroadcastedDeclareTransactionV0; +use mp_utils::service::{MadaraServiceId, MadaraServiceStatus}; +use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; use starknet_types_rpc::ClassAndTxnHash; +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "snake_case")] +pub enum ServiceRequest { + Start, + Stop, + Restart, +} + /// This is an admin method, so semver is different! #[versioned_rpc("V0_1_0", "madara")] pub trait MadaraWriteRpcApi { @@ -44,69 +54,11 @@ pub trait MadaraStatusRpcApi { #[versioned_rpc("V0_1_0", "madara")] pub trait MadaraServicesRpcApi { - /// Disables user-facing rpc services. - /// - /// This only works if user rpc has been enabled on startup, otherwise this - /// does nothing. - /// - /// # Returns - /// - /// True if user rpc was previously enabled. - #[method(name = "rpcDisable")] - async fn service_rpc_disable(&self) -> RpcResult<bool>; - - /// Enables user-facing rpc services. - /// - /// This only works if user rpc has been enabled on startup, otherwise this - /// does nothing. - /// - /// # Returns - /// - /// True if user rpc was previously enabled. - #[method(name = "rpcEnable")] - async fn service_rpc_enable(&self) -> RpcResult<bool>; - - /// Restarts user-facing rpc services, with a 5s grace period in between. - /// - /// This only works if user rpc has been enabled on startup, otherwise this - /// does nothing. - /// - /// # Returns - /// - /// True if user rpc was previously enabled. - #[method(name = "rpcRestart")] - async fn service_rpc_restart(&self) -> RpcResult<bool>; - - /// Disables l1 and l2 sync services. - /// - /// This only works if sync services have been enabled on startup, otherwise - /// this does nothing. - /// - /// # Returns - /// - /// True if any of l1 or l2 sync was previously enabled. - #[method(name = "syncDisable")] - async fn service_sync_disable(&self) -> RpcResult<bool>; - - /// Enables l1 and l2 sync services. - /// - /// This only works if sync services have been enabled on startup, otherwise - /// this does nothing. - /// - /// # Returns - /// - /// True if any of l1 or l2 sync was previously enabled. - #[method(name = "syncEnable")] - async fn service_sync_enable(&self) -> RpcResult<bool>; - - /// Disables l1 and l2 sync services, with a 5s grace period in between. - /// - /// This only works if sync services have been enabled on startup, otherwise - /// this does nothing. + /// Sets the status of one or more services /// /// # Returns /// - /// True if l1 or l2 sync was previously enabled. - #[method(name = "syncRestart")] - async fn service_sync_restart(&self) -> RpcResult<bool>; + /// * 'on' if any service was active before being toggled, 'off' otherwise. + #[method(name = "service")] + async fn service(&self, service: Vec<MadaraServiceId>, status: ServiceRequest) -> RpcResult<MadaraServiceStatus>; } diff --git a/crates/client/rpc/src/versions/admin/v0_1_0/methods/services.rs b/crates/client/rpc/src/versions/admin/v0_1_0/methods/services.rs index 2257aec7c..c0a43fc79 100644 --- a/crates/client/rpc/src/versions/admin/v0_1_0/methods/services.rs +++ b/crates/client/rpc/src/versions/admin/v0_1_0/methods/services.rs @@ -1,70 +1,67 @@ use std::time::Duration; use jsonrpsee::core::{async_trait, RpcResult}; -use mp_utils::service::MadaraService; +use mp_utils::service::{MadaraServiceId, MadaraServiceStatus, ServiceContext}; -use crate::{versions::admin::v0_1_0::MadaraServicesRpcApiV0_1_0Server, Starknet}; +use crate::{ + versions::admin::v0_1_0::{MadaraServicesRpcApiV0_1_0Server, ServiceRequest}, + Starknet, +}; const RESTART_INTERVAL: Duration = Duration::from_secs(5); #[async_trait] impl MadaraServicesRpcApiV0_1_0Server for Starknet { - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_rpc_disable(&self) -> RpcResult<bool> { - tracing::info!("🔌 Stopping RPC service..."); - Ok(self.ctx.service_remove(MadaraService::Rpc)) + async fn service(&self, service: Vec<MadaraServiceId>, status: ServiceRequest) -> RpcResult<MadaraServiceStatus> { + if service.is_empty() { + Err(jsonrpsee::types::ErrorObject::owned( + jsonrpsee::types::ErrorCode::InvalidParams.code(), + "You must provide at least one service to toggle", + Some(()), + )) + } else { + match status { + ServiceRequest::Start => service_start(&self.ctx, &service), + ServiceRequest::Stop => service_stop(&self.ctx, &service), + ServiceRequest::Restart => service_restart(&self.ctx, &service).await, + } + } } +} - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_rpc_enable(&self) -> RpcResult<bool> { - tracing::info!("🔌 Starting RPC service..."); - Ok(self.ctx.service_add(MadaraService::Rpc)) +fn service_start(ctx: &ServiceContext, svcs: &[MadaraServiceId]) -> RpcResult<MadaraServiceStatus> { + let mut status = MadaraServiceStatus::Off; + for svc in svcs { + tracing::info!("🔌 Starting {} service...", svc); + status |= ctx.service_add(*svc); } - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_rpc_restart(&self) -> RpcResult<bool> { - tracing::info!("🔌 Restarting RPC service..."); - - let res = self.ctx.service_remove(MadaraService::Rpc); - tokio::time::sleep(RESTART_INTERVAL).await; - self.ctx.service_add(MadaraService::Rpc); - - tracing::info!("🔌 Restart complete (Rpc)"); + Ok(status) +} - return Ok(res); +fn service_stop(ctx: &ServiceContext, svcs: &[MadaraServiceId]) -> RpcResult<MadaraServiceStatus> { + let mut status = MadaraServiceStatus::Off; + for svc in svcs { + tracing::info!("🔌 Stopping {} service...", svc); + status |= ctx.service_remove(*svc); } - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_sync_disable(&self) -> RpcResult<bool> { - tracing::info!("🔌 Stopping Sync service..."); - - let res = self.ctx.service_remove(MadaraService::L1Sync) | self.ctx.service_remove(MadaraService::L2Sync); + Ok(status) +} - Ok(res) +async fn service_restart(ctx: &ServiceContext, svcs: &[MadaraServiceId]) -> RpcResult<MadaraServiceStatus> { + let mut status = MadaraServiceStatus::Off; + for svc in svcs { + tracing::info!("🔌 Restarting {} service...", svc); + status |= ctx.service_remove(*svc); } - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_sync_enable(&self) -> RpcResult<bool> { - tracing::info!("🔌 Starting Sync service..."); - - let res = self.ctx.service_add(MadaraService::L1Sync) | self.ctx.service_add(MadaraService::L2Sync); + tokio::time::sleep(RESTART_INTERVAL).await; - Ok(res) + for svc in svcs { + ctx.service_add(*svc); + tracing::info!("🔌 Restart {} complete", svc); } - #[tracing::instrument(skip(self), fields(module = "Admin"))] - async fn service_sync_restart(&self) -> RpcResult<bool> { - tracing::info!("🔌 Stopping Sync service..."); - - let res = self.ctx.service_remove(MadaraService::L1Sync) | self.ctx.service_remove(MadaraService::L2Sync); - - tokio::time::sleep(RESTART_INTERVAL).await; - - self.ctx.service_add(MadaraService::L1Sync); - self.ctx.service_add(MadaraService::L2Sync); - - tracing::info!("🔌 Restart complete (Sync)"); - - Ok(res) - } + Ok(status) } diff --git a/crates/client/rpc/src/versions/admin/v0_1_0/methods/write.rs b/crates/client/rpc/src/versions/admin/v0_1_0/methods/write.rs index aa907b3bb..ddbd65b5b 100644 --- a/crates/client/rpc/src/versions/admin/v0_1_0/methods/write.rs +++ b/crates/client/rpc/src/versions/admin/v0_1_0/methods/write.rs @@ -20,6 +20,6 @@ impl MadaraWriteRpcApiV0_1_0Server for Starknet { &self, declare_transaction: BroadcastedDeclareTransactionV0, ) -> RpcResult<ClassAndTxnHash<Felt>> { - Ok(self.add_transaction_provider.add_declare_v0_transaction(declare_transaction).await?) + self.add_transaction_provider.add_declare_v0_transaction(declare_transaction).await } } diff --git a/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_block_with_tx_hashes.rs b/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_block_with_tx_hashes.rs index 9222c6afc..0d19896e4 100644 --- a/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_block_with_tx_hashes.rs +++ b/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_block_with_tx_hashes.rs @@ -20,7 +20,6 @@ use crate::Starknet; /// Returns block information with transaction hashes. This includes either a confirmed block or /// a pending block with transaction hashes, depending on the state of the requested block. /// In case the block is not found, returns a `StarknetRpcApiError` with `BlockNotFound`. - pub fn get_block_with_tx_hashes( starknet: &Starknet, block_id: BlockId, diff --git a/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_nonce.rs b/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_nonce.rs index 6d1bf7dd4..f2764f5a5 100644 --- a/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_nonce.rs +++ b/crates/client/rpc/src/versions/user/v0_7_1/methods/read/get_nonce.rs @@ -21,7 +21,6 @@ use crate::Starknet; /// count or other contract-specific operations. In case of errors, such as /// `BLOCK_NOT_FOUND` or `CONTRACT_NOT_FOUND`, returns a `StarknetRpcApiError` indicating the /// specific issue. - pub fn get_nonce(starknet: &Starknet, block_id: BlockId, contract_address: Felt) -> StarknetRpcResult<Felt> { // Check if block exists. We have to return a different error in that case. let block_exists = diff --git a/crates/client/rpc/src/versions/user/v0_7_1/methods/write/mod.rs b/crates/client/rpc/src/versions/user/v0_7_1/methods/write/mod.rs index 08b18af20..345883884 100644 --- a/crates/client/rpc/src/versions/user/v0_7_1/methods/write/mod.rs +++ b/crates/client/rpc/src/versions/user/v0_7_1/methods/write/mod.rs @@ -21,7 +21,7 @@ impl StarknetWriteRpcApiV0_7_1Server for Starknet { &self, declare_transaction: BroadcastedDeclareTxn<Felt>, ) -> RpcResult<ClassAndTxnHash<Felt>> { - Ok(self.add_transaction_provider.add_declare_transaction(declare_transaction).await?) + self.add_transaction_provider.add_declare_transaction(declare_transaction).await } /// Add an Deploy Account Transaction @@ -38,7 +38,7 @@ impl StarknetWriteRpcApiV0_7_1Server for Starknet { &self, deploy_account_transaction: BroadcastedDeployAccountTxn<Felt>, ) -> RpcResult<ContractAndTxnHash<Felt>> { - Ok(self.add_transaction_provider.add_deploy_account_transaction(deploy_account_transaction).await?) + self.add_transaction_provider.add_deploy_account_transaction(deploy_account_transaction).await } /// Add an Invoke Transaction to invoke a contract function @@ -54,6 +54,6 @@ impl StarknetWriteRpcApiV0_7_1Server for Starknet { &self, invoke_transaction: BroadcastedInvokeTxn<Felt>, ) -> RpcResult<AddInvokeTransactionResult<Felt>> { - Ok(self.add_transaction_provider.add_invoke_transaction(invoke_transaction).await?) + self.add_transaction_provider.add_invoke_transaction(invoke_transaction).await } } diff --git a/crates/client/sync/src/fetch/fetchers.rs b/crates/client/sync/src/fetch/fetchers.rs index fc0f559e8..b26f3a801 100644 --- a/crates/client/sync/src/fetch/fetchers.rs +++ b/crates/client/sync/src/fetch/fetchers.rs @@ -13,8 +13,8 @@ use mp_gateway::block::{ProviderBlock, ProviderBlockPending}; use mp_gateway::error::{SequencerError, StarknetError, StarknetErrorCode}; use mp_gateway::state_update::ProviderStateUpdateWithBlockPendingMaybe::{self}; use mp_gateway::state_update::{ProviderStateUpdate, ProviderStateUpdatePending, StateDiff}; -use mp_utils::service::ServiceContext; -use mp_utils::{stopwatch_end, wait_or_graceful_shutdown, PerfStopwatch}; +use mp_utils::service::MadaraServiceId; +use mp_utils::{stopwatch_end, PerfStopwatch}; use starknet_api::core::ChainId; use starknet_types_core::felt::Felt; use std::sync::Arc; @@ -49,19 +49,30 @@ pub struct FetchConfig { pub stop_on_sync: bool, /// Number of blocks to fetch in parallel during the sync process pub sync_parallelism: u8, - /// True if the node is called with `--warp-update-receiver` - pub warp_update: bool, + /// Warp update configuration + pub warp_update: Option<WarpUpdateConfig>, +} + +#[derive(Clone, Debug)] +pub struct WarpUpdateConfig { /// The port used for nodes to make rpc calls during a warp update. pub warp_update_port_rpc: u16, /// The port used for nodes to send blocks during a warp update. pub warp_update_port_fgw: u16, + /// Whether to shutdown the warp update sender once the migration has completed. + pub warp_update_shutdown_sender: bool, + /// Whether to shut down the warp update receiver once the migration has completed + pub warp_update_shutdown_receiver: bool, + /// A list of services to start once warp update has completed. + pub deferred_service_start: Vec<MadaraServiceId>, + /// A list of services to stop one warp update has completed. + pub deferred_service_stop: Vec<MadaraServiceId>, } pub async fn fetch_pending_block_and_updates( parent_block_hash: Felt, chain_id: &ChainId, provider: &GatewayProvider, - ctx: &ServiceContext, ) -> Result<Option<UnverifiedPendingFullBlock>, FetchError> { let block_id = BlockId::Tag(BlockTag::Pending); let sw = PerfStopwatch::new(); @@ -80,7 +91,6 @@ pub async fn fetch_pending_block_and_updates( }, MAX_RETRY, BASE_DELAY, - ctx, ) .await?; @@ -99,7 +109,7 @@ pub async fn fetch_pending_block_and_updates( ); return Ok(None); } - let class_update = fetch_class_updates(chain_id, &state_update.state_diff, block_id.clone(), provider, ctx).await?; + let class_update = fetch_class_updates(chain_id, &state_update.state_diff, block_id.clone(), provider).await?; stopwatch_end!(sw, "fetching {:?}: {:?}", block_id); @@ -113,7 +123,6 @@ pub async fn fetch_block_and_updates( chain_id: &ChainId, block_n: u64, provider: &GatewayProvider, - ctx: &ServiceContext, ) -> Result<UnverifiedFullBlock, FetchError> { let block_id = BlockId::Number(block_n); @@ -127,10 +136,9 @@ pub async fn fetch_block_and_updates( }, MAX_RETRY, BASE_DELAY, - ctx, ) .await?; - let class_update = fetch_class_updates(chain_id, state_update.state_diff(), block_id, provider, ctx).await?; + let class_update = fetch_class_updates(chain_id, state_update.state_diff(), block_id, provider).await?; stopwatch_end!(sw, "fetching {:?}: {:?}", block_n); @@ -143,12 +151,8 @@ pub async fn fetch_block_and_updates( Ok(converted) } -async fn retry<F, Fut, T>( - mut f: F, - max_retries: u32, - base_delay: Duration, - ctx: &ServiceContext, -) -> Result<T, SequencerError> +// TODO: should we be checking for cancellation here? This might take a while +async fn retry<F, Fut, T>(mut f: F, max_retries: u32, base_delay: Duration) -> Result<T, SequencerError> where F: FnMut() -> Fut, Fut: std::future::Future<Output = Result<T, SequencerError>>, @@ -173,9 +177,7 @@ where tracing::warn!("The provider has returned an error: {}, retrying in {:?}", err, delay) } - if wait_or_graceful_shutdown(tokio::time::sleep(delay), ctx).await.is_none() { - return Err(SequencerError::StarknetError(StarknetError::block_not_found())); - } + tokio::time::sleep(delay).await; } } } @@ -187,7 +189,6 @@ async fn fetch_class_updates( state_diff: &StateDiff, block_id: BlockId, provider: &GatewayProvider, - ctx: &ServiceContext, ) -> anyhow::Result<Vec<ClassUpdate>> { // for blocks before 2597 on mainnet new classes are not declared in the state update // https://github.com/madara-alliance/madara/issues/233 @@ -208,7 +209,7 @@ async fn fetch_class_updates( let block_id = block_id.clone(); async move { let (class_hash, contract_class) = - retry(|| fetch_class(class_hash, block_id.clone(), provider), MAX_RETRY, BASE_DELAY, ctx).await?; + retry(|| fetch_class(class_hash, block_id.clone(), provider), MAX_RETRY, BASE_DELAY).await?; let ContractClass::Legacy(contract_class) = contract_class else { return Err(L2SyncError::UnexpectedClassType { class_hash }); @@ -225,7 +226,7 @@ async fn fetch_class_updates( let block_id = block_id.clone(); async move { let (class_hash, contract_class) = - retry(|| fetch_class(class_hash, block_id.clone(), provider), MAX_RETRY, BASE_DELAY, ctx).await?; + retry(|| fetch_class(class_hash, block_id.clone(), provider), MAX_RETRY, BASE_DELAY).await?; let ContractClass::Sierra(contract_class) = contract_class else { return Err(L2SyncError::UnexpectedClassType { class_hash }); @@ -350,7 +351,6 @@ mod test_l2_fetchers { Felt::from_hex_unchecked("0x1db054847816dbc0098c88915430c44da2c1e3f910fbcb454e14282baba0e75"), &ctx.backend.chain_config().chain_id, &ctx.provider, - &ServiceContext::new_for_testing(), ) .await; @@ -436,7 +436,6 @@ mod test_l2_fetchers { Felt::from_hex_unchecked("0x1db054847816dbc0098c88915430c44da2c1e3f910fbcb454e14282baba0e75"), &ctx.backend.chain_config().chain_id, &ctx.provider, - &ServiceContext::new_for_testing(), ) .await; @@ -644,15 +643,10 @@ mod test_l2_fetchers { .state_update(); let state_diff = state_update.state_diff(); - let class_updates = fetch_class_updates( - &ctx.backend.chain_config().chain_id, - state_diff, - BlockId::Number(5), - &ctx.provider, - &ServiceContext::new_for_testing(), - ) - .await - .expect("Failed to fetch class updates"); + let class_updates = + fetch_class_updates(&ctx.backend.chain_config().chain_id, state_diff, BlockId::Number(5), &ctx.provider) + .await + .expect("Failed to fetch class updates"); assert!(!class_updates.is_empty(), "Should have fetched at least one class update"); @@ -681,14 +675,9 @@ mod test_l2_fetchers { let state_diff = state_update.state_diff(); ctx.mock_class_hash_not_found("0x40fe2533528521fc49a8ad8440f8a1780c50337a94d0fce43756015fa816a8a".to_string()); - let result = fetch_class_updates( - &ctx.backend.chain_config().chain_id, - state_diff, - BlockId::Number(5), - &ctx.provider, - &ServiceContext::new_for_testing(), - ) - .await; + let result = + fetch_class_updates(&ctx.backend.chain_config().chain_id, state_diff, BlockId::Number(5), &ctx.provider) + .await; assert!(matches!( result, diff --git a/crates/client/sync/src/fetch/fetchers_real_fgw_test.rs b/crates/client/sync/src/fetch/fetchers_real_fgw_test.rs index 5c194aac3..ddfd5fb14 100644 --- a/crates/client/sync/src/fetch/fetchers_real_fgw_test.rs +++ b/crates/client/sync/src/fetch/fetchers_real_fgw_test.rs @@ -11,14 +11,7 @@ fn client_mainnet_fixture() -> GatewayProvider { #[rstest] #[tokio::test] async fn test_can_fetch_pending_block(client_mainnet_fixture: GatewayProvider) { - let block = fetch_pending_block_and_updates( - Felt::ZERO, - &ChainId::Mainnet, - &client_mainnet_fixture, - &ServiceContext::new_for_testing(), - ) - .await - .unwrap(); + let block = fetch_pending_block_and_updates(Felt::ZERO, &ChainId::Mainnet, &client_mainnet_fixture).await.unwrap(); // ignore as we can't check much here :/ drop(block); } @@ -32,14 +25,7 @@ async fn test_can_fetch_and_convert_block(client_mainnet_fixture: GatewayProvide // Sorting is necessary since we store storage diffs and nonces in a // hashmap in the fgw types before converting them to a Vec in the mp // types, resulting in unpredictable ordering - let mut block = fetch_block_and_updates( - &ChainId::Mainnet, - block_n, - &client_mainnet_fixture, - &ServiceContext::new_for_testing(), - ) - .await - .unwrap(); + let mut block = fetch_block_and_updates(&ChainId::Mainnet, block_n, &client_mainnet_fixture).await.unwrap(); block.state_diff.storage_diffs.sort_by(|a, b| a.address.cmp(&b.address)); block.state_diff.nonces.sort_by(|a, b| a.contract_address.cmp(&b.contract_address)); diff --git a/crates/client/sync/src/fetch/mod.rs b/crates/client/sync/src/fetch/mod.rs index 8e9c6de71..8ae306194 100644 --- a/crates/client/sync/src/fetch/mod.rs +++ b/crates/client/sync/src/fetch/mod.rs @@ -7,12 +7,14 @@ use mc_db::MadaraBackend; use mc_gateway_client::GatewayProvider; use mc_rpc::versions::admin::v0_1_0::MadaraStatusRpcApiV0_1_0Client; use mp_gateway::error::{SequencerError, StarknetError, StarknetErrorCode}; -use mp_utils::{channel_wait_or_graceful_shutdown, service::ServiceContext, wait_or_graceful_shutdown}; +use mp_utils::service::ServiceContext; use tokio::sync::{mpsc, oneshot}; use url::Url; use crate::fetch::fetchers::fetch_block_and_updates; +use self::fetchers::WarpUpdateConfig; + pub mod fetchers; pub struct L2FetchConfig { @@ -23,23 +25,27 @@ pub struct L2FetchConfig { pub n_blocks_to_sync: Option<u64>, pub stop_on_sync: bool, pub sync_parallelism: usize, - pub warp_update: bool, - pub warp_update_port_rpc: u16, - pub warp_update_port_fgw: u16, + pub warp_update: Option<WarpUpdateConfig>, } pub async fn l2_fetch_task( backend: Arc<MadaraBackend>, provider: Arc<GatewayProvider>, - ctx: ServiceContext, + mut ctx: ServiceContext, mut config: L2FetchConfig, ) -> anyhow::Result<()> { // First, catch up with the chain - // let backend = &backend; - - let L2FetchConfig { first_block, warp_update, warp_update_port_rpc, warp_update_port_fgw, .. } = config; - - if warp_update { + let L2FetchConfig { first_block, ref warp_update, .. } = config; + + if let Some(WarpUpdateConfig { + warp_update_port_rpc, + warp_update_port_fgw, + warp_update_shutdown_sender, + warp_update_shutdown_receiver, + deferred_service_start, + deferred_service_stop, + }) = warp_update + { let client = jsonrpsee::http_client::HttpClientBuilder::default() .build(format!("http://localhost:{warp_update_port_rpc}")) .expect("Building client"); @@ -62,15 +68,29 @@ pub async fn l2_fetch_task( .unwrap_or(NonZeroUsize::new(1usize).expect("1 should always be in usize bound")); config.sync_parallelism = Into::<usize>::into(available_parallelism) * 2; - let next_block = match sync_blocks(backend.as_ref(), &provider, &ctx, &config).await? { + let next_block = match sync_blocks(backend.as_ref(), &provider, &mut ctx, &config).await? { SyncStatus::Full(next_block) => next_block, SyncStatus::UpTo(next_block) => next_block, }; - if client.shutdown().await.is_err() { - tracing::error!("❗ Failed to shutdown warp update sender"); - ctx.cancel_global(); - return Ok(()); + if *warp_update_shutdown_sender { + if client.shutdown().await.is_err() { + tracing::error!("❗ Failed to shutdown warp update sender"); + ctx.cancel_global(); + return Ok(()); + } + + for svc_id in deferred_service_stop { + ctx.service_remove(*svc_id); + } + + for svc_id in deferred_service_start { + ctx.service_add(*svc_id); + } + } + + if *warp_update_shutdown_receiver { + return anyhow::Ok(()); } config.n_blocks_to_sync = config.n_blocks_to_sync.map(|n| n - (next_block - first_block)); @@ -78,7 +98,7 @@ pub async fn l2_fetch_task( config.sync_parallelism = save; } - let mut next_block = match sync_blocks(backend.as_ref(), &provider, &ctx, &config).await? { + let mut next_block = match sync_blocks(backend.as_ref(), &provider, &mut ctx, &config).await? { SyncStatus::Full(next_block) => { tracing::info!("🥳 The sync process has caught up with the tip of the chain"); next_block @@ -105,17 +125,27 @@ pub async fn l2_fetch_task( let mut interval = tokio::time::interval(sync_polling_interval); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - while wait_or_graceful_shutdown(interval.tick(), &ctx).await.is_some() { - loop { - match fetch_block_and_updates(&backend.chain_config().chain_id, next_block, &provider, &ctx).await { + while ctx.run_until_cancelled(interval.tick()).await.is_some() { + // It is possible the chain produces multiple blocks in the span of + // a single loop iteration, so we keep fetching until we reach the + // tip again. + let chain_id = &backend.chain_config().chain_id; + let fetch = |next_block: u64| fetch_block_and_updates(chain_id, next_block, &provider); + + while let Some(block) = ctx.run_until_cancelled(fetch(next_block)).await { + match block { Err(FetchError::Sequencer(SequencerError::StarknetError(StarknetError { code: StarknetErrorCode::BlockNotFound, .. }))) => { break; } - val => { - if fetch_stream_sender.send(val?).await.is_err() { + Err(e) => { + tracing::debug!("Failed to poll latest block: {e}"); + return Err(e.into()); + } + Ok(unverified_block) => { + if fetch_stream_sender.send(unverified_block).await.is_err() { // stream closed break; } @@ -126,7 +156,8 @@ pub async fn l2_fetch_task( } } } - Ok(()) + + anyhow::Ok(()) } /// Whether a chain has been caught up to the tip or only a certain block number @@ -154,27 +185,24 @@ enum SyncStatus { async fn sync_blocks( backend: &MadaraBackend, provider: &Arc<GatewayProvider>, - ctx: &ServiceContext, + ctx: &mut ServiceContext, config: &L2FetchConfig, ) -> anyhow::Result<SyncStatus> { let L2FetchConfig { first_block, fetch_stream_sender, n_blocks_to_sync, sync_parallelism, .. } = config; // Fetch blocks and updates in parallel one time before looping - let fetch_stream = - (*first_block..).take(n_blocks_to_sync.unwrap_or(u64::MAX) as _).map(|block_n| { - let provider = Arc::clone(provider); - let ctx = ctx.clone(); - async move { - (block_n, fetch_block_and_updates(&backend.chain_config().chain_id, block_n, &provider, &ctx).await) - } - }); + let fetch_stream = (*first_block..).take(n_blocks_to_sync.unwrap_or(u64::MAX) as _).map(|block_n| { + let provider = Arc::clone(provider); + let chain_id = &backend.chain_config().chain_id; + async move { (block_n, fetch_block_and_updates(chain_id, block_n, &provider).await) } + }); // Have `sync_parallelism` fetches in parallel at once, using futures Buffered let mut next_block = *first_block; let mut fetch_stream = stream::iter(fetch_stream).buffered(*sync_parallelism); - loop { - let Some((block_n, val)) = channel_wait_or_graceful_shutdown(fetch_stream.next(), ctx).await else { + while let Some(next) = ctx.run_until_cancelled(fetch_stream.next()).await { + let Some((block_n, val)) = next else { return anyhow::Ok(SyncStatus::UpTo(next_block)); }; @@ -195,6 +223,8 @@ async fn sync_blocks( next_block = block_n + 1; } + + anyhow::Ok(SyncStatus::UpTo(next_block)) } #[derive(thiserror::Error, Debug)] @@ -253,9 +283,7 @@ mod test_l2_fetch_task { n_blocks_to_sync: Some(5), stop_on_sync: false, sync_parallelism: 10, - warp_update: false, - warp_update_port_rpc: 9943, - warp_update_port_fgw: 8080, + warp_update: None, }, ), ) diff --git a/crates/client/sync/src/l2.rs b/crates/client/sync/src/l2.rs index 356edc08b..dad900f43 100644 --- a/crates/client/sync/src/l2.rs +++ b/crates/client/sync/src/l2.rs @@ -1,5 +1,6 @@ //! Contains the code required to sync data from the feeder efficiently. use crate::fetch::fetchers::fetch_pending_block_and_updates; +use crate::fetch::fetchers::WarpUpdateConfig; use crate::fetch::l2_fetch_task; use crate::fetch::L2FetchConfig; use crate::utils::trim_hash; @@ -16,7 +17,7 @@ use mp_block::BlockId; use mp_block::BlockTag; use mp_gateway::error::SequencerError; use mp_utils::service::ServiceContext; -use mp_utils::{channel_wait_or_graceful_shutdown, wait_or_graceful_shutdown, PerfStopwatch}; +use mp_utils::PerfStopwatch; use starknet_api::core::ChainId; use starknet_types_core::felt::Felt; use std::pin::pin; @@ -52,7 +53,7 @@ pub struct L2VerifyApplyConfig { flush_every_n_blocks: u64, flush_every_n_seconds: u64, stop_on_sync: bool, - telemetry: TelemetryHandle, + telemetry: Arc<TelemetryHandle>, validation: BlockValidationContext, block_conv_receiver: mpsc::Receiver<PreValidatedBlock>, } @@ -60,7 +61,7 @@ pub struct L2VerifyApplyConfig { #[tracing::instrument(skip(backend, ctx, config), fields(module = "Sync"))] async fn l2_verify_and_apply_task( backend: Arc<MadaraBackend>, - ctx: ServiceContext, + mut ctx: ServiceContext, config: L2VerifyApplyConfig, ) -> anyhow::Result<()> { let L2VerifyApplyConfig { @@ -78,7 +79,7 @@ async fn l2_verify_and_apply_task( let mut instant = std::time::Instant::now(); let target_duration = std::time::Duration::from_secs(flush_every_n_seconds); - while let Some(block) = channel_wait_or_graceful_shutdown(pin!(block_conv_receiver.recv()), &ctx).await { + while let Some(Some(block)) = ctx.run_until_cancelled(pin!(block_conv_receiver.recv())).await { let BlockImportResult { header, block_hash } = block_import.verify_apply(block, validation.clone()).await?; if header.block_number - last_block_n >= flush_every_n_blocks || instant.elapsed() >= target_duration { @@ -122,7 +123,7 @@ async fn l2_verify_and_apply_task( ctx.cancel_global() } - Ok(()) + anyhow::Ok(()) } async fn l2_block_conversion_task( @@ -130,14 +131,14 @@ async fn l2_block_conversion_task( output: mpsc::Sender<PreValidatedBlock>, block_import: Arc<BlockImporter>, validation: BlockValidationContext, - ctx: ServiceContext, + mut ctx: ServiceContext, ) -> anyhow::Result<()> { // Items of this stream are futures that resolve to blocks, which becomes a regular stream of blocks // using futures buffered. let conversion_stream = stream::unfold( (updates_receiver, block_import, validation.clone(), ctx.clone()), |(mut updates_recv, block_import, validation, ctx)| async move { - channel_wait_or_graceful_shutdown(updates_recv.recv(), &ctx).await.map(|block| { + updates_recv.recv().await.map(|block| { let block_import_ = Arc::clone(&block_import); let validation_ = validation.clone(); ( @@ -149,13 +150,14 @@ async fn l2_block_conversion_task( ); let mut stream = pin!(conversion_stream.buffered(10)); - while let Some(block) = channel_wait_or_graceful_shutdown(stream.next(), &ctx).await { + while let Some(Some(block)) = ctx.run_until_cancelled(stream.next()).await { if output.send(block?).await.is_err() { // channel closed break; } } - Ok(()) + + anyhow::Ok(()) } struct L2PendingBlockConfig { @@ -168,7 +170,7 @@ struct L2PendingBlockConfig { async fn l2_pending_block_task( backend: Arc<MadaraBackend>, provider: Arc<GatewayProvider>, - ctx: ServiceContext, + mut ctx: ServiceContext, config: L2PendingBlockConfig, ) -> anyhow::Result<()> { let L2PendingBlockConfig { block_import, once_caught_up_receiver, pending_block_poll_interval, validation } = @@ -190,17 +192,18 @@ async fn l2_pending_block_task( let mut interval = tokio::time::interval(pending_block_poll_interval); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - while wait_or_graceful_shutdown(interval.tick(), &ctx).await.is_some() { + while ctx.run_until_cancelled(interval.tick()).await.is_some() { tracing::debug!("Getting pending block..."); let current_block_hash = backend .get_block_hash(&BlockId::Tag(BlockTag::Latest)) .context("Getting latest block hash")? .unwrap_or(/* genesis parent block hash */ Felt::ZERO); - let Some(block) = - fetch_pending_block_and_updates(current_block_hash, &backend.chain_config().chain_id, &provider, &ctx) - .await - .context("Getting pending block from FGW")? + + let chain_id = &backend.chain_config().chain_id; + let Some(block) = fetch_pending_block_and_updates(current_block_hash, chain_id, &provider) + .await + .context("Getting pending block from FGW")? else { continue; }; @@ -214,7 +217,7 @@ async fn l2_pending_block_task( }; if let Err(err) = import_block().await { - tracing::debug!("Error while importing pending block: {err:#}"); + tracing::debug!("Failed to import pending block: {err:#}"); } } @@ -233,18 +236,16 @@ pub struct L2SyncConfig { pub flush_every_n_seconds: u64, pub pending_block_poll_interval: Duration, pub ignore_block_order: bool, - pub warp_update: bool, - pub warp_update_port_rpc: u16, - pub warp_update_port_fgw: u16, pub chain_id: ChainId, - pub telemetry: TelemetryHandle, + pub telemetry: Arc<TelemetryHandle>, pub block_importer: Arc<BlockImporter>, + pub warp_update: Option<WarpUpdateConfig>, } /// Spawns workers to fetch blocks and state updates from the feeder. #[tracing::instrument(skip(backend, provider, ctx, config), fields(module = "Sync"))] pub async fn sync( - backend: &Arc<MadaraBackend>, + backend: Arc<MadaraBackend>, provider: GatewayProvider, ctx: ServiceContext, config: L2SyncConfig, @@ -272,8 +273,11 @@ pub async fn sync( }; let mut join_set = JoinSet::new(); + let warp_update_shutdown_sender = + config.warp_update.as_ref().map(|w| w.warp_update_shutdown_receiver).unwrap_or(false); + join_set.spawn(l2_fetch_task( - Arc::clone(backend), + Arc::clone(&backend), Arc::clone(&provider), ctx.clone(), L2FetchConfig { @@ -285,8 +289,6 @@ pub async fn sync( stop_on_sync: config.stop_on_sync, sync_parallelism: config.sync_parallelism as usize, warp_update: config.warp_update, - warp_update_port_rpc: config.warp_update_port_rpc, - warp_update_port_fgw: config.warp_update_port_fgw, }, )); join_set.spawn(l2_block_conversion_task( @@ -297,21 +299,21 @@ pub async fn sync( ctx.clone(), )); join_set.spawn(l2_verify_and_apply_task( - Arc::clone(backend), + Arc::clone(&backend), ctx.clone(), L2VerifyApplyConfig { block_import: Arc::clone(&config.block_importer), backup_every_n_blocks: config.backup_every_n_blocks, flush_every_n_blocks: config.flush_every_n_blocks, flush_every_n_seconds: config.flush_every_n_seconds, - stop_on_sync: config.stop_on_sync, + stop_on_sync: config.stop_on_sync || warp_update_shutdown_sender, telemetry: config.telemetry, validation: validation.clone(), block_conv_receiver, }, )); join_set.spawn(l2_pending_block_task( - Arc::clone(backend), + Arc::clone(&backend), provider, ctx.clone(), L2PendingBlockConfig { @@ -370,7 +372,7 @@ mod tests { let (block_conv_sender, block_conv_receiver) = mpsc::channel(100); let block_import = Arc::new(BlockImporter::new(backend.clone(), None).unwrap()); let validation = BlockValidationContext::new(backend.chain_config().chain_id.clone()); - let telemetry = TelemetryService::new(true, vec![]).unwrap().new_handle(); + let telemetry = Arc::new(TelemetryService::new(vec![]).unwrap().new_handle()); let mock_block = create_dummy_unverified_full_block(); diff --git a/crates/client/sync/src/lib.rs b/crates/client/sync/src/lib.rs index 6bb874f9e..588ed353b 100644 --- a/crates/client/sync/src/lib.rs +++ b/crates/client/sync/src/lib.rs @@ -21,13 +21,13 @@ pub struct SyncConfig { pub block_importer: Arc<BlockImporter>, pub starting_block: Option<u64>, pub backup_every_n_blocks: Option<u64>, - pub telemetry: TelemetryHandle, + pub telemetry: Arc<TelemetryHandle>, pub pending_block_poll_interval: Duration, } #[tracing::instrument(skip(backend, ctx, fetch_config, sync_config))] pub async fn l2_sync_worker( - backend: &Arc<MadaraBackend>, + backend: Arc<MadaraBackend>, ctx: ServiceContext, fetch_config: FetchConfig, sync_config: SyncConfig, @@ -56,31 +56,25 @@ pub async fn l2_sync_worker( ) } - l2::sync( - backend, - provider, - ctx, - L2SyncConfig { - first_block: starting_block, - n_blocks_to_sync: fetch_config.n_blocks_to_sync, - stop_on_sync: fetch_config.stop_on_sync, - verify: fetch_config.verify, - sync_polling_interval: fetch_config.sync_polling_interval, - backup_every_n_blocks: sync_config.backup_every_n_blocks, - flush_every_n_blocks: fetch_config.flush_every_n_blocks, - flush_every_n_seconds: fetch_config.flush_every_n_seconds, - pending_block_poll_interval: sync_config.pending_block_poll_interval, - ignore_block_order, - sync_parallelism: fetch_config.sync_parallelism, - warp_update: fetch_config.warp_update, - warp_update_port_rpc: fetch_config.warp_update_port_rpc, - warp_update_port_fgw: fetch_config.warp_update_port_fgw, - chain_id: backend.chain_config().chain_id.clone(), - telemetry: sync_config.telemetry, - block_importer: sync_config.block_importer, - }, - ) - .await?; + let l2_config = L2SyncConfig { + first_block: starting_block, + n_blocks_to_sync: fetch_config.n_blocks_to_sync, + stop_on_sync: fetch_config.stop_on_sync, + verify: fetch_config.verify, + sync_polling_interval: fetch_config.sync_polling_interval, + backup_every_n_blocks: sync_config.backup_every_n_blocks, + flush_every_n_blocks: fetch_config.flush_every_n_blocks, + flush_every_n_seconds: fetch_config.flush_every_n_seconds, + pending_block_poll_interval: sync_config.pending_block_poll_interval, + ignore_block_order, + sync_parallelism: fetch_config.sync_parallelism, + chain_id: backend.chain_config().chain_id.clone(), + telemetry: sync_config.telemetry, + block_importer: sync_config.block_importer, + warp_update: fetch_config.warp_update, + }; + + l2::sync(backend, provider, ctx, l2_config).await?; Ok(()) } diff --git a/crates/client/sync/src/tests/utils/read_resource.rs b/crates/client/sync/src/tests/utils/read_resource.rs index 1e57d11b9..ae0fde06d 100644 --- a/crates/client/sync/src/tests/utils/read_resource.rs +++ b/crates/client/sync/src/tests/utils/read_resource.rs @@ -5,5 +5,5 @@ use std::string::String; pub fn read_resource_file(path_in_resource_dir: &str) -> String { let path = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()).join("resources").join(path_in_resource_dir); - return read_to_string(path.to_str().unwrap()).unwrap(); + read_to_string(path.to_str().unwrap()).unwrap() } diff --git a/crates/client/telemetry/src/lib.rs b/crates/client/telemetry/src/lib.rs index c5e3bebdd..3ce417f65 100644 --- a/crates/client/telemetry/src/lib.rs +++ b/crates/client/telemetry/src/lib.rs @@ -1,13 +1,9 @@ -use std::sync::Arc; use std::time::SystemTime; use anyhow::Context; use futures::SinkExt; -use mp_utils::channel_wait_or_graceful_shutdown; -use mp_utils::service::{MadaraService, Service, ServiceContext}; -use reqwest_websocket::{Message, RequestBuilderExt}; -use tokio::sync::mpsc; -use tokio::task::JoinSet; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceContext, ServiceId, ServiceRunner}; +use reqwest_websocket::{Message, RequestBuilderExt, WebSocket}; mod sysinfo; pub use sysinfo::*; @@ -18,42 +14,33 @@ pub enum VerbosityLevel { Debug = 1, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TelemetryEvent { verbosity: VerbosityLevel, message: serde_json::Value, } #[derive(Debug, Clone)] -pub struct TelemetryHandle(Option<Arc<mpsc::Sender<TelemetryEvent>>>); +#[repr(transparent)] +pub struct TelemetryHandle(tokio::sync::broadcast::Sender<TelemetryEvent>); impl TelemetryHandle { pub fn send(&self, verbosity: VerbosityLevel, message: serde_json::Value) { if message.get("msg").is_none() { tracing::warn!("Telemetry messages should have a message type"); } - if let Some(tx) = &self.0 { - // drop the message if the channel if full. - let _ = tx.try_send(TelemetryEvent { verbosity, message }); - } + let _ = self.0.send(TelemetryEvent { verbosity, message }); } } pub struct TelemetryService { - telemetry: bool, telemetry_endpoints: Vec<(String, u8)>, telemetry_handle: TelemetryHandle, - start_state: Option<mpsc::Receiver<TelemetryEvent>>, } impl TelemetryService { - pub fn new(telemetry: bool, telemetry_endpoints: Vec<(String, u8)>) -> anyhow::Result<Self> { - let (telemetry_handle, start_state) = if !telemetry { - (TelemetryHandle(None), None) - } else { - let (tx, rx) = mpsc::channel(1024); - (TelemetryHandle(Some(Arc::new(tx))), Some(rx)) - }; - Ok(Self { telemetry, telemetry_endpoints, telemetry_handle, start_state }) + pub fn new(telemetry_endpoints: Vec<(String, u8)>) -> anyhow::Result<Self> { + let telemetry_handle = TelemetryHandle(tokio::sync::broadcast::channel(1024).0); + Ok(Self { telemetry_endpoints, telemetry_handle }) } pub fn new_handle(&self) -> TelemetryHandle { @@ -93,71 +80,72 @@ impl TelemetryService { #[async_trait::async_trait] impl Service for TelemetryService { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - if !self.telemetry { - return Ok(()); - } + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { + let rx = self.telemetry_handle.0.subscribe(); + let clients = start_clients(&self.telemetry_endpoints).await; - let telemetry_endpoints = self.telemetry_endpoints.clone(); - let mut rx = self.start_state.take().context("the service has already been started")?; - - join_set.spawn(async move { - let client = &reqwest::Client::default(); - let mut clients = futures::future::join_all(telemetry_endpoints.iter().map(|(endpoint, pr)| async move { - let websocket = match client.get(endpoint).upgrade().send().await { - Ok(ws) => ws, - Err(err) => { - tracing::warn!("Could not connect to telemetry endpoint '{endpoint}': {err:?}"); - return None; - } - }; - let websocket = match websocket.into_websocket().await { - Ok(ws) => ws, - Err(err) => { - tracing::warn!("Could not connect websocket to telemetry endpoint '{endpoint}': {err:?}"); - return None; - } - }; - Some((websocket, *pr, endpoint.clone())) - })) - .await; - - let rx = &mut rx; - - while let Some(event) = channel_wait_or_graceful_shutdown(rx.recv(), &ctx).await { - tracing::debug!( - "Sending telemetry event '{}'.", - event.message.get("msg").and_then(|e| e.as_str()).unwrap_or("<unknown>") - ); - let ts = chrono::Local::now().to_rfc3339(); - let msg = serde_json::json!({ "id": 1, "ts": ts, "payload": event.message }); - let msg = &serde_json::to_string(&msg).context("serializing telemetry message to string")?; - - futures::future::join_all(clients.iter_mut().map(|client| async move { - if let Some((websocket, verbosity, endpoint)) = client { - if *verbosity >= event.verbosity as u8 { - tracing::trace!("send telemetry to '{endpoint}'"); - match websocket.send(Message::Text(msg.clone())).await { - Ok(_) => {} - Err(err) => { - tracing::warn!( - "Could not connect send telemetry to endpoint '{endpoint}': {err:#}" - ); - } - } - } - } - })) - .await; - } + runner.service_loop(move |ctx| start_telemetry(rx, ctx, clients)); - Ok(()) - }); + anyhow::Ok(()) + } +} - Ok(()) +impl ServiceId for TelemetryService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::Telemetry.svc_id() } +} + +async fn start_clients(telemetry_endpoints: &[(String, u8)]) -> Vec<Option<(WebSocket, u8, String)>> { + let client = &reqwest::Client::default(); + futures::future::join_all(telemetry_endpoints.iter().map(|(endpoint, pr)| async move { + let websocket = match client.get(endpoint).upgrade().send().await { + Ok(ws) => ws, + Err(err) => { + tracing::warn!("Failed to connect to telemetry endpoint '{endpoint}': {err:?}"); + return None; + } + }; + let websocket = match websocket.into_websocket().await { + Ok(ws) => ws, + Err(err) => { + tracing::warn!("Failed to connect websocket to telemetry endpoint '{endpoint}': {err:?}"); + return None; + } + }; + Some((websocket, *pr, endpoint.clone())) + })) + .await +} - fn id(&self) -> MadaraService { - MadaraService::Telemetry +async fn start_telemetry( + mut rx: tokio::sync::broadcast::Receiver<TelemetryEvent>, + mut ctx: ServiceContext, + mut clients: Vec<Option<(WebSocket, u8, String)>>, +) -> anyhow::Result<()> { + while let Some(Ok(event)) = ctx.run_until_cancelled(rx.recv()).await { + tracing::debug!( + "Sending telemetry event '{}'.", + event.message.get("msg").and_then(|e| e.as_str()).unwrap_or("<unknown>") + ); + + let ts = chrono::Local::now().to_rfc3339(); + let msg = serde_json::json!({ "id": 1, "ts": ts, "payload": event.message }); + let msg = &serde_json::to_string(&msg).context("serializing telemetry message to string")?; + + futures::future::join_all(clients.iter_mut().map(|client| async move { + if let Some((websocket, verbosity, endpoint)) = client { + if *verbosity >= event.verbosity as u8 { + tracing::trace!("Sending telemetry to '{endpoint}'"); + if let Err(err) = websocket.send(Message::Text(msg.clone())).await { + tracing::warn!("Failed to send telemetry to endpoint '{endpoint}': {err:#}"); + } + } + } + })) + .await; } + + anyhow::Ok(()) } diff --git a/crates/node/src/cli/l1.rs b/crates/node/src/cli/l1.rs index cd5edf7c0..2d28a8b25 100644 --- a/crates/node/src/cli/l1.rs +++ b/crates/node/src/cli/l1.rs @@ -8,7 +8,7 @@ use mp_utils::parsers::{parse_duration, parse_url}; pub struct L1SyncParams { /// Disable L1 sync. #[clap(env = "MADARA_SYNC_L1_DISABLED", long, alias = "no-l1-sync", conflicts_with = "l1_endpoint")] - pub sync_l1_disabled: bool, + pub l1_sync_disabled: bool, /// The L1 rpc endpoint url for state verification. #[clap(env = "MADARA_L1_ENDPOINT", long, value_parser = parse_url, value_name = "ETHEREUM RPC URL")] diff --git a/crates/node/src/cli/sync.rs b/crates/node/src/cli/l2.rs similarity index 91% rename from crates/node/src/cli/sync.rs rename to crates/node/src/cli/l2.rs index 987b91730..3765f76db 100644 --- a/crates/node/src/cli/sync.rs +++ b/crates/node/src/cli/l2.rs @@ -1,5 +1,6 @@ use std::{sync::Arc, time::Duration}; +use mc_sync::fetch::fetchers::WarpUpdateConfig; use mp_chain_config::ChainConfig; use starknet_api::core::ChainId; @@ -11,10 +12,10 @@ use super::FGW_DEFAULT_PORT; use super::RPC_DEFAULT_PORT_ADMIN; #[derive(Clone, Debug, clap::Args)] -pub struct SyncParams { +pub struct L2SyncParams { /// Disable the sync service. The sync service is responsible for listening for new blocks on starknet and ethereum. #[clap(env = "MADARA_SYNC_DISABLED", long, alias = "no-sync")] - pub sync_disabled: bool, + pub l2_sync_disabled: bool, /// The block you want to start syncing from. This will most probably break your database. #[clap(env = "MADARA_UNSAFE_STARTING_BLOCK", long, value_name = "BLOCK NUMBER")] @@ -39,9 +40,17 @@ pub struct SyncParams { pub warp_update_port_rpc: u16, /// The port used for nodes to send blocks during a warp update. - #[arg(env = "MADARA_WARP_UPDATE_PORT_FGW", long, value_name = "WARP UPDATE FGW", default_value_t = FGW_DEFAULT_PORT)] + #[arg(env = "MADARA_WARP_UPDATE_PORT_FGW", long, value_name = "WARP UPDATE PORT FGW", default_value_t = FGW_DEFAULT_PORT)] pub warp_update_port_fgw: u16, + /// Whether to shut down the warp update sender once the migration has completed + #[arg(env = "MADARA_WARP_UPDATE_SHUTDOWN_SENDER", long, default_value_t = false)] + pub warp_update_shutdown_sender: bool, + + /// Whether to shut down the warp update receiver once the migration has completed + #[arg(env = "MADARA_WARP_UPDATE_SHUTDOWN_RECEIVER", long, default_value_t = false)] + pub warp_update_shutdown_receiver: bool, + /// Polling interval, in seconds. This only affects the sync service once it has caught up with the blockchain tip. #[clap( env = "MADARA_SYNC_POLLING_INTERVAL", @@ -137,12 +146,12 @@ pub struct SyncParams { pub sync_parallelism: u8, } -impl SyncParams { +impl L2SyncParams { pub fn block_fetch_config( &self, chain_id: ChainId, chain_config: Arc<ChainConfig>, - warp_update: bool, + warp_update: Option<WarpUpdateConfig>, ) -> FetchConfig { let (gateway, feeder_gateway) = match &self.gateway_url { Some(url) => ( @@ -167,8 +176,6 @@ impl SyncParams { stop_on_sync: self.stop_on_sync, sync_parallelism: self.sync_parallelism, warp_update, - warp_update_port_rpc: self.warp_update_port_rpc, - warp_update_port_fgw: self.warp_update_port_fgw, } } } diff --git a/crates/node/src/cli/mod.rs b/crates/node/src/cli/mod.rs index 8cbf58bdf..557f6d8de 100644 --- a/crates/node/src/cli/mod.rs +++ b/crates/node/src/cli/mod.rs @@ -4,8 +4,8 @@ pub mod chain_config_overrides; pub mod db; pub mod gateway; pub mod l1; +pub mod l2; pub mod rpc; -pub mod sync; pub mod telemetry; use crate::cli::l1::L1SyncParams; use analytics::AnalyticsParams; @@ -14,10 +14,10 @@ pub use block_production::*; pub use chain_config_overrides::*; pub use db::*; pub use gateway::*; +pub use l2::*; pub use rpc::*; use starknet_api::core::ChainId; use std::str::FromStr; -pub use sync::*; pub use telemetry::*; use clap::ArgGroup; @@ -150,7 +150,7 @@ pub struct RunCmd { #[allow(missing_docs)] #[clap(flatten)] - pub sync_params: SyncParams, + pub l2_sync_params: L2SyncParams, #[allow(missing_docs)] #[clap(flatten)] @@ -213,11 +213,9 @@ impl RunCmd { pub fn apply_arg_preset(mut self) -> Self { if self.args_preset.warp_update_sender { self.gateway_params.feeder_gateway_enable = true; - self.gateway_params.gateway_port = self.sync_params.warp_update_port_fgw; + self.gateway_params.gateway_port = self.l2_sync_params.warp_update_port_fgw; self.rpc_params.rpc_admin = true; - self.rpc_params.rpc_admin_port = self.sync_params.warp_update_port_rpc; - } else if self.args_preset.warp_update_receiver { - self.rpc_params.rpc_disable = true; + self.rpc_params.rpc_admin_port = self.l2_sync_params.warp_update_port_rpc; } else if self.args_preset.gateway { self.gateway_params.feeder_gateway_enable = true; self.gateway_params.gateway_enable = true; diff --git a/crates/node/src/main.rs b/crates/node/src/main.rs index f30f3e80c..bb14def92 100644 --- a/crates/node/src/main.rs +++ b/crates/node/src/main.rs @@ -15,9 +15,10 @@ use mc_db::{DatabaseService, TrieLogConfig}; use mc_gateway_client::GatewayProvider; use mc_mempool::{GasPriceProvider, L1DataProvider, Mempool, MempoolLimits}; use mc_rpc::providers::{AddTransactionProvider, ForwardToProvider, MempoolAddTxProvider}; +use mc_sync::fetch::fetchers::WarpUpdateConfig; use mc_telemetry::{SysInfo, TelemetryService}; use mp_oracle::pragma::PragmaOracleBuilder; -use mp_utils::service::{Service, ServiceGroup}; +use mp_utils::service::{MadaraServiceId, ServiceMonitor}; use service::{BlockProductionService, GatewayService, L1SyncService, L2SyncService, RpcService}; use std::sync::Arc; @@ -29,7 +30,7 @@ async fn main() -> anyhow::Result<()> { crate::util::setup_rayon_threadpool()?; crate::util::raise_fdlimit(); - let mut run_cmd: RunCmd = RunCmd::parse().apply_arg_preset(); + let mut run_cmd = RunCmd::parse().apply_arg_preset(); // Setting up analytics @@ -49,13 +50,25 @@ async fn main() -> anyhow::Result<()> { run_cmd.chain_config()? }; + // Check if the devnet is running with the correct chain id. + if run_cmd.devnet && chain_config.chain_id != NetworkType::Devnet.chain_id() { + if !run_cmd.block_production_params.override_devnet_chain_id { + tracing::error!("You're running a devnet with the network config of {:?}. This means that devnet transactions can be replayed on the actual network. Use `--network=devnet` instead. Or if this is the expected behavior please pass `--override-devnet-chain-id`", chain_config.chain_name); + panic!(); + } else { + // This log is immediately flooded with devnet accounts and so this can be missed. + // Should we add a delay here to make this clearly visisble? + tracing::warn!("You're running a devnet with the network config of {:?}. This means that devnet transactions can be replayed on the actual network.", run_cmd.network); + } + } + let node_name = run_cmd.node_name_or_provide().await.to_string(); let node_version = env!("MADARA_BUILD_VERSION"); - tracing::info!("🥷 {} Node", GREET_IMPL_NAME); + tracing::info!("🥷 {} Node", GREET_IMPL_NAME); tracing::info!("✌️ Version {}", node_version); tracing::info!("💁 Support URL: {}", GREET_SUPPORT_URL); - tracing::info!("🏷 Node Name: {}", node_name); + tracing::info!("🏷 Node Name: {}", node_name); let role = if run_cmd.is_sequencer() { "Sequencer" } else { "Full Node" }; tracing::info!("👤 Role: {}", role); tracing::info!("🌐 Network: {} (chain id `{}`)", chain_config.chain_name, chain_config.chain_id); @@ -64,13 +77,19 @@ async fn main() -> anyhow::Result<()> { let sys_info = SysInfo::probe(); sys_info.show(); - // Services. + // ===================================================================== // + // SERVICES (SETUP) // + // ===================================================================== // - let telemetry_service: TelemetryService = - TelemetryService::new(run_cmd.telemetry_params.telemetry, run_cmd.telemetry_params.telemetry_endpoints.clone()) + // Telemetry + + let service_telemetry: TelemetryService = + TelemetryService::new(run_cmd.telemetry_params.telemetry_endpoints.clone()) .context("Initializing telemetry service")?; - let db_service = DatabaseService::new( + // Database + + let service_db = DatabaseService::new( &run_cmd.db_params.base_path, run_cmd.db_params.backup_dir.clone(), run_cmd.db_params.restore_from_latest_backup, @@ -84,10 +103,7 @@ async fn main() -> anyhow::Result<()> { .await .context("Initializing db service")?; - let importer = Arc::new( - BlockImporter::new(Arc::clone(db_service.backend()), run_cmd.sync_params.unsafe_starting_block) - .context("Initializing importer service")?, - ); + // L1 Sync let mut l1_gas_setter = GasPriceProvider::new(); @@ -117,7 +133,7 @@ async fn main() -> anyhow::Result<()> { } } - if !run_cmd.l1_sync_params.sync_l1_disabled + if !run_cmd.l1_sync_params.l1_sync_disabled && l1_gas_setter.is_oracle_needed() && l1_gas_setter.oracle_provider.is_none() { @@ -128,16 +144,16 @@ async fn main() -> anyhow::Result<()> { // declare mempool here so that it can be used to process l1->l2 messages in the l1 service let mut mempool = Mempool::new( - Arc::clone(db_service.backend()), + Arc::clone(service_db.backend()), Arc::clone(&l1_data_provider), MempoolLimits::new(&chain_config), ); mempool.load_txs_from_db().context("Loading mempool transactions")?; let mempool = Arc::new(mempool); - let l1_service = L1SyncService::new( + let service_l1_sync = L1SyncService::new( &run_cmd.l1_sync_params, - &db_service, + &service_db, l1_gas_setter, chain_config.chain_id.clone(), chain_config.eth_core_contract_address, @@ -148,84 +164,178 @@ async fn main() -> anyhow::Result<()> { .await .context("Initializing the l1 sync service")?; - // Block provider startup. - // `rpc_add_txs_method_provider` is a trait object that tells the RPC task where to put the transactions when using the Write endpoints. - let (block_provider_service, rpc_add_txs_method_provider): (_, Arc<dyn AddTransactionProvider>) = - match run_cmd.is_sequencer() { - // Block production service. (authority) - true => { - let block_production_service = BlockProductionService::new( - &run_cmd.block_production_params, - &db_service, - Arc::clone(&mempool), - importer, - Arc::clone(&l1_data_provider), - run_cmd.devnet, - telemetry_service.new_handle(), - )?; - - (ServiceGroup::default().with(block_production_service), Arc::new(MempoolAddTxProvider::new(mempool))) - } - // Block sync service. (full node) - false => { - // Feeder gateway sync service. - let sync_service = L2SyncService::new( - &run_cmd.sync_params, - Arc::clone(&chain_config), - &db_service, - importer, - telemetry_service.new_handle(), - run_cmd.args_preset.warp_update_receiver, - ) - .await - .context("Initializing sync service")?; - - let mut provider = - GatewayProvider::new(chain_config.gateway_url.clone(), chain_config.feeder_gateway_url.clone()); - // gateway api key is needed for declare transactions on mainnet - if let Some(api_key) = run_cmd.sync_params.gateway_key { - provider.add_header( - HeaderName::from_static("x-throttling-bypass"), - HeaderValue::from_str(&api_key).with_context(|| "Invalid API key format")?, - ) - } - - (ServiceGroup::default().with(sync_service), Arc::new(ForwardToProvider::new(provider))) - } - }; - - let rpc_service = - RpcService::new(run_cmd.rpc_params, Arc::clone(db_service.backend()), Arc::clone(&rpc_add_txs_method_provider)); - - let gateway_service = GatewayService::new(run_cmd.gateway_params, &db_service, rpc_add_txs_method_provider) - .await - .context("Initializing gateway service")?; - - telemetry_service.send_connected(&node_name, node_version, &chain_config.chain_name, &sys_info); - - let app = ServiceGroup::default() - .with(db_service) - .with(l1_service) - .with(block_provider_service) - .with(rpc_service) - .with(gateway_service) - .with(telemetry_service); + // L2 Sync - // Check if the devnet is running with the correct chain id. - if run_cmd.devnet && chain_config.chain_id != NetworkType::Devnet.chain_id() { - if !run_cmd.block_production_params.override_devnet_chain_id { - tracing::error!("You are running a devnet with the network config of {:?}. This means that devnet transactions can be replayed on the actual network. Use `--network=devnet` instead. Or if this is the expected behavior please pass `--override-devnet-chain-id`", chain_config.chain_name); - panic!(); - } else { - // This log is immediately flooded with devnet accounts and so this can be missed. - // Should we add a delay here to make this clearly visisble? - tracing::warn!("You are running a devnet with the network config of {:?}. This means that devnet transactions can be replayed on the actual network.", run_cmd.network); + let importer = Arc::new( + BlockImporter::new(Arc::clone(service_db.backend()), run_cmd.l2_sync_params.unsafe_starting_block) + .context("Initializing importer service")?, + ); + + let warp_update = if run_cmd.args_preset.warp_update_receiver { + let mut deferred_service_start = vec![]; + let mut deferred_service_stop = vec![]; + + if !run_cmd.rpc_params.rpc_disable { + deferred_service_start.push(MadaraServiceId::RpcUser); } + + if run_cmd.rpc_params.rpc_admin { + deferred_service_start.push(MadaraServiceId::RpcAdmin); + } + + if run_cmd.gateway_params.feeder_gateway_enable { + deferred_service_start.push(MadaraServiceId::Gateway); + } + + if run_cmd.telemetry_params.telemetry { + deferred_service_start.push(MadaraServiceId::Telemetry); + } + + if run_cmd.is_sequencer() { + deferred_service_start.push(MadaraServiceId::BlockProduction); + deferred_service_stop.push(MadaraServiceId::L2Sync); + } + + Some(WarpUpdateConfig { + warp_update_port_rpc: run_cmd.l2_sync_params.warp_update_port_rpc, + warp_update_port_fgw: run_cmd.l2_sync_params.warp_update_port_fgw, + warp_update_shutdown_sender: run_cmd.l2_sync_params.warp_update_shutdown_sender, + warp_update_shutdown_receiver: run_cmd.l2_sync_params.warp_update_shutdown_receiver, + deferred_service_start, + deferred_service_stop, + }) + } else { + None + }; + + let service_l2_sync = L2SyncService::new( + &run_cmd.l2_sync_params, + Arc::clone(&chain_config), + &service_db, + importer, + service_telemetry.new_handle(), + warp_update, + ) + .await + .context("Initializing sync service")?; + + let mut provider = GatewayProvider::new(chain_config.gateway_url.clone(), chain_config.feeder_gateway_url.clone()); + + // gateway api key is needed for declare transactions on mainnet + if let Some(api_key) = run_cmd.l2_sync_params.gateway_key.clone() { + provider.add_header( + HeaderName::from_static("x-throttling-bypass"), + HeaderValue::from_str(&api_key).with_context(|| "Invalid API key format")?, + ) + } + + // Block production + + let importer = Arc::new( + BlockImporter::new(Arc::clone(service_db.backend()), run_cmd.l2_sync_params.unsafe_starting_block) + .context("Initializing importer service")?, + ); + let service_block_production = BlockProductionService::new( + &run_cmd.block_production_params, + &service_db, + Arc::clone(&mempool), + importer, + Arc::clone(&l1_data_provider), + )?; + + // Add transaction provider + let add_tx_provider_l2_sync: Arc<dyn AddTransactionProvider> = Arc::new(ForwardToProvider::new(provider)); + let add_tx_provider_mempool: Arc<dyn AddTransactionProvider> = Arc::new(MempoolAddTxProvider::new(mempool)); + + // User-facing RPC + + let service_rpc_user = RpcService::user( + run_cmd.rpc_params.clone(), + Arc::clone(service_db.backend()), + Arc::clone(&add_tx_provider_l2_sync), + Arc::clone(&add_tx_provider_mempool), + ); + + // Admin-facing RPC (for node operators) + + let service_rpc_admin = RpcService::admin( + run_cmd.rpc_params.clone(), + Arc::clone(service_db.backend()), + Arc::clone(&add_tx_provider_l2_sync), + Arc::clone(&add_tx_provider_mempool), + ); + + // Feeder gateway + + let service_gateway = GatewayService::new( + run_cmd.gateway_params.clone(), + Arc::clone(service_db.backend()), + Arc::clone(&add_tx_provider_l2_sync), + Arc::clone(&add_tx_provider_mempool), + ) + .await + .context("Initializing gateway service")?; + + service_telemetry.send_connected(&node_name, node_version, &chain_config.chain_name, &sys_info); + + // ===================================================================== // + // SERVICES (START) // + // ===================================================================== // + + if run_cmd.is_sequencer() { + service_block_production.setup_devnet().await?; + } + + let app = ServiceMonitor::default() + .with(service_db)? + .with(service_l1_sync)? + .with(service_l2_sync)? + .with(service_block_production)? + .with(service_rpc_user)? + .with(service_rpc_admin)? + .with(service_gateway)? + .with(service_telemetry)?; + + // Since the database is not implemented as a proper service, we do not + // active it, as it would never be marked as stopped by the existing logic + // + // app.activate(MadaraService::Database); + + let l1_sync_enabled = !run_cmd.l1_sync_params.l1_sync_disabled; + let l1_endpoint_some = run_cmd.l1_sync_params.l1_endpoint.is_some(); + let warp_update_receiver = run_cmd.args_preset.warp_update_receiver; + + if l1_sync_enabled && (l1_endpoint_some || !run_cmd.devnet) { + app.activate(MadaraServiceId::L1Sync); + } + + if warp_update_receiver { + app.activate(MadaraServiceId::L2Sync); + } else if run_cmd.is_sequencer() { + app.activate(MadaraServiceId::BlockProduction); + } else if !run_cmd.l2_sync_params.l2_sync_disabled { + app.activate(MadaraServiceId::L2Sync); + } + + if !run_cmd.rpc_params.rpc_disable && !warp_update_receiver { + app.activate(MadaraServiceId::RpcUser); + } + + if run_cmd.rpc_params.rpc_admin && !warp_update_receiver { + app.activate(MadaraServiceId::RpcAdmin); + } + + if run_cmd.gateway_params.feeder_gateway_enable && !warp_update_receiver { + app.activate(MadaraServiceId::Gateway); + } + + if run_cmd.telemetry_params.telemetry && !warp_update_receiver { + app.activate(MadaraServiceId::Telemetry); } - app.start_and_drive_to_end().await?; + app.start().await?; let _ = analytics.shutdown(); - Ok(()) + anyhow::Ok(()) } diff --git a/crates/node/src/service/block_production.rs b/crates/node/src/service/block_production.rs index e8c759b40..0d6b4dbbe 100644 --- a/crates/node/src/service/block_production.rs +++ b/crates/node/src/service/block_production.rs @@ -5,25 +5,18 @@ use mc_block_production::{metrics::BlockProductionMetrics, BlockProductionTask}; use mc_db::{DatabaseService, MadaraBackend}; use mc_devnet::{ChainGenesisDescription, DevnetKeys}; use mc_mempool::{L1DataProvider, Mempool}; -use mc_telemetry::TelemetryHandle; -use mp_utils::service::{MadaraService, Service, ServiceContext}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId, ServiceRunner}; use std::{io::Write, sync::Arc}; -use tokio::task::JoinSet; -struct StartParams { +pub struct BlockProductionService { backend: Arc<MadaraBackend>, block_import: Arc<BlockImporter>, mempool: Arc<Mempool>, - metrics: BlockProductionMetrics, + metrics: Arc<BlockProductionMetrics>, l1_data_provider: Arc<dyn L1DataProvider>, - is_devnet: bool, n_devnet_contracts: u64, } -pub struct BlockProductionService { - start: Option<StartParams>, - enabled: bool, -} impl BlockProductionService { #[allow(clippy::too_many_arguments)] pub fn new( @@ -32,26 +25,16 @@ impl BlockProductionService { mempool: Arc<mc_mempool::Mempool>, block_import: Arc<BlockImporter>, l1_data_provider: Arc<dyn L1DataProvider>, - is_devnet: bool, - _telemetry: TelemetryHandle, ) -> anyhow::Result<Self> { - if config.block_production_disabled { - return Ok(Self { start: None, enabled: false }); - } - - let metrics = BlockProductionMetrics::register(); + let metrics = Arc::new(BlockProductionMetrics::register()); Ok(Self { - start: Some(StartParams { - backend: Arc::clone(db_service.backend()), - l1_data_provider, - mempool, - metrics, - block_import, - n_devnet_contracts: config.devnet_contracts, - is_devnet, - }), - enabled: true, + backend: Arc::clone(db_service.backend()), + l1_data_provider, + mempool, + metrics, + block_import, + n_devnet_contracts: config.devnet_contracts, }) } } @@ -59,66 +42,72 @@ impl BlockProductionService { #[async_trait::async_trait] impl Service for BlockProductionService { // TODO(cchudant,2024-07-30): special threading requirements for the block production task - #[tracing::instrument(skip(self, join_set, ctx), fields(module = "BlockProductionService"))] - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - if !self.enabled { - return Ok(()); - } - let StartParams { backend, l1_data_provider, mempool, metrics, is_devnet, n_devnet_contracts, block_import } = - self.start.take().expect("Service already started"); - - if is_devnet { - // DEVNET: we the genesis block for the devnet if not deployed, otherwise we only print the devnet keys. - - let keys = if backend.get_latest_block_n().context("Getting the latest block number in db")?.is_none() { - // deploy devnet genesis - - tracing::info!("⛏️ Deploying devnet genesis block"); - - let mut genesis_config = - ChainGenesisDescription::base_config().context("Failed to create base genesis config")?; - let contracts = genesis_config - .add_devnet_contracts(n_devnet_contracts) - .context("Failed to add devnet contracts")?; - - let genesis_block = genesis_config - .build(backend.chain_config()) - .context("Building genesis block from devnet config")?; - - block_import - .add_block( - genesis_block, - BlockValidationContext::new(backend.chain_config().chain_id.clone()).trust_class_hashes(true), - ) - .await - .context("Importing devnet genesis block")?; - - contracts.save_to_db(&backend).context("Saving predeployed devnet contract keys to database")?; - - contracts - } else { - DevnetKeys::from_db(&backend).context("Getting the devnet predeployed contract keys and balances")? - }; - - // display devnet welcome message :) - // we display it to stdout instead of stderr - - let msg = format!("{}", keys); - - std::io::stdout().write(msg.as_bytes()).context("Writing devnet welcome message to stdout")?; - } - - join_set.spawn(async move { - BlockProductionTask::new(backend, block_import, mempool, metrics, l1_data_provider)? - .block_production_task(ctx) - .await?; - Ok(()) - }); + #[tracing::instrument(skip(self, runner), fields(module = "BlockProductionService"))] + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { + let Self { backend, l1_data_provider, mempool, metrics, block_import, .. } = self; + + let block_production_task = BlockProductionTask::new( + Arc::clone(backend), + Arc::clone(block_import), + Arc::clone(mempool), + Arc::clone(metrics), + Arc::clone(l1_data_provider), + )?; + + runner.service_loop(move |ctx| block_production_task.block_production_task(ctx)); Ok(()) } +} - fn id(&self) -> MadaraService { - MadaraService::BlockProduction +impl ServiceId for BlockProductionService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::BlockProduction.svc_id() + } +} + +impl BlockProductionService { + /// Initializes the genesis state of a devnet. This is needed for local sequencers. + /// + /// This methods was made external to [Service::start] as it needs to be + /// called on node startup even if sequencer block production is not yet + /// enabled. This happens during warp updates on a local sequencer. + pub async fn setup_devnet(&self) -> anyhow::Result<()> { + let Self { backend, n_devnet_contracts, block_import, .. } = self; + + let keys = if backend.get_latest_block_n().context("Getting the latest block number in db")?.is_none() { + // deploy devnet genesis + tracing::info!("⛏️ Deploying devnet genesis block"); + + let mut genesis_config = + ChainGenesisDescription::base_config().context("Failed to create base genesis config")?; + let contracts = + genesis_config.add_devnet_contracts(*n_devnet_contracts).context("Failed to add devnet contracts")?; + + let genesis_block = + genesis_config.build(backend.chain_config()).context("Building genesis block from devnet config")?; + + block_import + .add_block( + genesis_block, + BlockValidationContext::new(backend.chain_config().chain_id.clone()).trust_class_hashes(true), + ) + .await + .context("Importing devnet genesis block")?; + + contracts.save_to_db(backend).context("Saving predeployed devnet contract keys to database")?; + + contracts + } else { + DevnetKeys::from_db(backend).context("Getting the devnet predeployed contract keys and balances")? + }; + + // display devnet welcome message :) + // we display it to stdout instead of stderr + let msg = format!("{}", keys); + std::io::stdout().write(msg.as_bytes()).context("Writing devnet welcome message to stdout")?; + + anyhow::Ok(()) } } diff --git a/crates/node/src/service/gateway.rs b/crates/node/src/service/gateway.rs index 6fbf6fc0e..6ba1af54c 100644 --- a/crates/node/src/service/gateway.rs +++ b/crates/node/src/service/gateway.rs @@ -1,50 +1,57 @@ use crate::cli::GatewayParams; -use mc_db::{DatabaseService, MadaraBackend}; -use mc_rpc::providers::AddTransactionProvider; -use mp_utils::service::{MadaraService, Service, ServiceContext}; +use mc_db::MadaraBackend; +use mc_rpc::providers::{AddTransactionProvider, AddTransactionProviderGroup}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId, ServiceRunner}; use std::sync::Arc; -use tokio::task::JoinSet; #[derive(Clone)] pub struct GatewayService { config: GatewayParams, db_backend: Arc<MadaraBackend>, - add_transaction_provider: Arc<dyn AddTransactionProvider>, + add_txs_provider_l2_sync: Arc<dyn AddTransactionProvider>, + add_txs_provider_mempool: Arc<dyn AddTransactionProvider>, } impl GatewayService { pub async fn new( config: GatewayParams, - db: &DatabaseService, - add_transaction_provider: Arc<dyn AddTransactionProvider>, + db_backend: Arc<MadaraBackend>, + add_txs_provider_l2_sync: Arc<dyn AddTransactionProvider>, + add_txs_provider_mempool: Arc<dyn AddTransactionProvider>, ) -> anyhow::Result<Self> { - Ok(Self { config, db_backend: Arc::clone(db.backend()), add_transaction_provider }) + Ok(Self { config, db_backend, add_txs_provider_l2_sync, add_txs_provider_mempool }) } } #[async_trait::async_trait] impl Service for GatewayService { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - if self.config.feeder_gateway_enable || self.config.gateway_enable { - let GatewayService { db_backend, add_transaction_provider, config } = self.clone(); + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { + let GatewayService { config, db_backend, add_txs_provider_l2_sync, add_txs_provider_mempool } = self.clone(); - join_set.spawn(async move { - mc_gateway_server::service::start_server( - db_backend, - add_transaction_provider, - config.feeder_gateway_enable, - config.gateway_enable, - config.gateway_external, - config.gateway_port, - ctx, - ) - .await - }); - } + runner.service_loop(move |ctx| { + let add_tx_provider = Arc::new(AddTransactionProviderGroup::new( + add_txs_provider_l2_sync, + add_txs_provider_mempool, + ctx.clone(), + )); + + mc_gateway_server::service::start_server( + db_backend, + add_tx_provider, + config.feeder_gateway_enable, + config.gateway_enable, + config.gateway_external, + config.gateway_port, + ctx, + ) + }); Ok(()) } +} - fn id(&self) -> MadaraService { - MadaraService::Gateway +impl ServiceId for GatewayService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::Gateway.svc_id() } } diff --git a/crates/node/src/service/l1.rs b/crates/node/src/service/l1.rs index 18c32f562..e305d741c 100644 --- a/crates/node/src/service/l1.rs +++ b/crates/node/src/service/l1.rs @@ -5,16 +5,15 @@ use mc_db::{DatabaseService, MadaraBackend}; use mc_eth::client::{EthereumClient, L1BlockMetrics}; use mc_mempool::{GasPriceProvider, Mempool}; use mp_block::H160; -use mp_utils::service::{MadaraService, Service, ServiceContext}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId, ServiceRunner}; use starknet_api::core::ChainId; use std::sync::Arc; use std::time::Duration; -use tokio::task::JoinSet; #[derive(Clone)] pub struct L1SyncService { db_backend: Arc<MadaraBackend>, - eth_client: Option<EthereumClient>, + eth_client: Option<Arc<EthereumClient>>, l1_gas_provider: GasPriceProvider, chain_id: ChainId, gas_price_sync_disabled: bool, @@ -34,15 +33,15 @@ impl L1SyncService { devnet: bool, mempool: Arc<Mempool>, ) -> anyhow::Result<Self> { - let eth_client = if !config.sync_l1_disabled && (config.l1_endpoint.is_some() || !devnet) { + let eth_client = if !config.l1_sync_disabled && (config.l1_endpoint.is_some() || !devnet) { if let Some(l1_rpc_url) = &config.l1_endpoint { let core_address = Address::from_slice(l1_core_address.as_bytes()); let l1_block_metrics = L1BlockMetrics::register().expect("Registering metrics"); - Some( - EthereumClient::new(l1_rpc_url.clone(), core_address, l1_block_metrics) - .await - .context("Creating ethereum client")?, - ) + let client = EthereumClient::new(l1_rpc_url.clone(), core_address, l1_block_metrics) + .await + .context("Creating ethereum client")?; + + Some(Arc::new(client)) } else { anyhow::bail!( "No Ethereum endpoint provided. You need to provide one using --l1-endpoint <RPC URL> in order to verify the synced state or disable the l1 watcher using --no-l1-sync." @@ -64,7 +63,7 @@ impl L1SyncService { .context("L1 gas prices require the ethereum service to be enabled. Either disable gas prices syncing using `--gas-price 0`, or disable L1 sync using the `--no-l1-sync` argument.")?; // running at-least once before the block production service tracing::info!("⏳ Getting initial L1 gas prices"); - mc_eth::l1_gas_price::gas_price_worker_once(ð_client, l1_gas_provider.clone(), gas_price_poll) + mc_eth::l1_gas_price::gas_price_worker_once(ð_client, &l1_gas_provider, gas_price_poll) .await .context("Getting initial ethereum gas prices")?; } @@ -83,18 +82,25 @@ impl L1SyncService { #[async_trait::async_trait] impl Service for L1SyncService { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - let L1SyncService { l1_gas_provider, chain_id, gas_price_sync_disabled, gas_price_poll, mempool, .. } = - self.clone(); + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { + let L1SyncService { + db_backend, + l1_gas_provider, + chain_id, + gas_price_sync_disabled, + gas_price_poll, + mempool, + .. + } = self.clone(); - if let Some(eth_client) = self.eth_client.take() { + if let Some(eth_client) = &self.eth_client { // enabled - let db_backend = Arc::clone(&self.db_backend); - join_set.spawn(async move { + let eth_client = Arc::clone(eth_client); + runner.service_loop(move |ctx| { mc_eth::sync::l1_sync_worker( - &db_backend, - ð_client, + db_backend, + eth_client, chain_id, l1_gas_provider, gas_price_sync_disabled, @@ -102,14 +108,18 @@ impl Service for L1SyncService { mempool, ctx, ) - .await }); + } else { + tracing::error!("❗ Tried to start L1 Sync but no l1 endpoint was provided to the node on startup"); } Ok(()) } +} - fn id(&self) -> MadaraService { - MadaraService::L1Sync +impl ServiceId for L1SyncService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::L1Sync.svc_id() } } diff --git a/crates/node/src/service/sync.rs b/crates/node/src/service/l2.rs similarity index 65% rename from crates/node/src/service/sync.rs rename to crates/node/src/service/l2.rs index 42a747c4a..2c843729c 100644 --- a/crates/node/src/service/sync.rs +++ b/crates/node/src/service/l2.rs @@ -1,15 +1,13 @@ -use crate::cli::SyncParams; -use anyhow::Context; +use crate::cli::L2SyncParams; use mc_block_import::BlockImporter; use mc_db::{DatabaseService, MadaraBackend}; -use mc_sync::fetch::fetchers::FetchConfig; +use mc_sync::fetch::fetchers::{FetchConfig, WarpUpdateConfig}; use mc_sync::SyncConfig; use mc_telemetry::TelemetryHandle; use mp_chain_config::ChainConfig; -use mp_utils::service::{MadaraService, Service, ServiceContext}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId, ServiceRunner}; use std::sync::Arc; use std::time::Duration; -use tokio::task::JoinSet; #[derive(Clone)] pub struct L2SyncService { @@ -18,23 +16,22 @@ pub struct L2SyncService { fetch_config: FetchConfig, backup_every_n_blocks: Option<u64>, starting_block: Option<u64>, - start_params: Option<TelemetryHandle>, - disabled: bool, + telemetry: Arc<TelemetryHandle>, pending_block_poll_interval: Duration, } impl L2SyncService { pub async fn new( - config: &SyncParams, + config: &L2SyncParams, chain_config: Arc<ChainConfig>, db: &DatabaseService, block_importer: Arc<BlockImporter>, telemetry: TelemetryHandle, - warp_update: bool, + warp_update: Option<WarpUpdateConfig>, ) -> anyhow::Result<Self> { let fetch_config = config.block_fetch_config(chain_config.chain_id.clone(), chain_config.clone(), warp_update); - tracing::info!("🛰️ Using feeder gateway URL: {}", fetch_config.feeder_gateway.as_str()); + tracing::info!("🛰️ Using feeder gateway URL: {}", fetch_config.feeder_gateway.as_str()); Ok(Self { db_backend: Arc::clone(db.backend()), @@ -42,8 +39,7 @@ impl L2SyncService { starting_block: config.unsafe_starting_block, backup_every_n_blocks: config.backup_every_n_blocks, block_importer, - start_params: Some(telemetry), - disabled: config.sync_disabled, + telemetry: Arc::new(telemetry), pending_block_poll_interval: config.pending_block_poll_interval, }) } @@ -51,25 +47,21 @@ impl L2SyncService { #[async_trait::async_trait] impl Service for L2SyncService { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - if self.disabled { - return Ok(()); - } + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { let L2SyncService { + db_backend, fetch_config, backup_every_n_blocks, starting_block, pending_block_poll_interval, block_importer, - .. + telemetry, } = self.clone(); - let telemetry = self.start_params.take().context("Service already started")?; + let telemetry = Arc::clone(&telemetry); - let db_backend = Arc::clone(&self.db_backend); - - join_set.spawn(async move { + runner.service_loop(move |ctx| { mc_sync::l2_sync_worker( - &db_backend, + db_backend, ctx, fetch_config, SyncConfig { @@ -80,13 +72,15 @@ impl Service for L2SyncService { pending_block_poll_interval, }, ) - .await }); Ok(()) } +} - fn id(&self) -> MadaraService { - MadaraService::L2Sync +impl ServiceId for L2SyncService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + MadaraServiceId::L2Sync.svc_id() } } diff --git a/crates/node/src/service/mod.rs b/crates/node/src/service/mod.rs index 6a95afd73..84a36057b 100644 --- a/crates/node/src/service/mod.rs +++ b/crates/node/src/service/mod.rs @@ -1,11 +1,11 @@ mod block_production; mod gateway; mod l1; +mod l2; mod rpc; -mod sync; pub use block_production::BlockProductionService; pub use gateway::GatewayService; pub use l1::L1SyncService; +pub use l2::L2SyncService; pub use rpc::RpcService; -pub use sync::L2SyncService; diff --git a/crates/node/src/service/rpc/mod.rs b/crates/node/src/service/rpc/mod.rs index 555befb9e..54ac1aafc 100644 --- a/crates/node/src/service/rpc/mod.rs +++ b/crates/node/src/service/rpc/mod.rs @@ -1,11 +1,13 @@ use std::sync::Arc; use jsonrpsee::server::ServerHandle; -use tokio::task::JoinSet; use mc_db::MadaraBackend; -use mc_rpc::{providers::AddTransactionProvider, rpc_api_admin, rpc_api_user, Starknet}; -use mp_utils::service::{MadaraService, Service, ServiceContext}; +use mc_rpc::{ + providers::{AddTransactionProvider, AddTransactionProviderGroup}, + rpc_api_admin, rpc_api_user, Starknet, +}; +use mp_utils::service::{MadaraServiceId, PowerOfTwo, Service, ServiceId, ServiceRunner}; use metrics::RpcMetrics; use server::{start_server, ServerConfig}; @@ -18,93 +20,126 @@ mod metrics; mod middleware; mod server; +#[derive(Clone)] +pub enum RpcType { + User, + Admin, +} + pub struct RpcService { config: RpcParams, backend: Arc<MadaraBackend>, - add_txs_method_provider: Arc<dyn AddTransactionProvider>, - server_handle_user: Option<ServerHandle>, - server_handle_admin: Option<ServerHandle>, + add_txs_provider_l2_sync: Arc<dyn AddTransactionProvider>, + add_txs_provider_mempool: Arc<dyn AddTransactionProvider>, + server_handle: Option<ServerHandle>, + rpc_type: RpcType, } impl RpcService { - pub fn new( + pub fn user( config: RpcParams, backend: Arc<MadaraBackend>, - add_txs_method_provider: Arc<dyn AddTransactionProvider>, + add_txs_provider_l2_sync: Arc<dyn AddTransactionProvider>, + add_txs_provider_mempool: Arc<dyn AddTransactionProvider>, ) -> Self { - Self { config, backend, add_txs_method_provider, server_handle_user: None, server_handle_admin: None } + Self { + config, + backend, + add_txs_provider_l2_sync, + add_txs_provider_mempool, + server_handle: None, + rpc_type: RpcType::User, + } + } + + pub fn admin( + config: RpcParams, + backend: Arc<MadaraBackend>, + add_txs_provider_l2_sync: Arc<dyn AddTransactionProvider>, + add_txs_provider_mempool: Arc<dyn AddTransactionProvider>, + ) -> Self { + Self { + config, + backend, + add_txs_provider_l2_sync, + add_txs_provider_mempool, + server_handle: None, + rpc_type: RpcType::Admin, + } } } #[async_trait::async_trait] impl Service for RpcService { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - let RpcService { config, backend, add_txs_method_provider, .. } = self; - - let starknet = - Starknet::new(backend.clone(), add_txs_method_provider.clone(), config.storage_proof_config(), ctx.clone()); - let metrics = RpcMetrics::register()?; - - let server_config_user = if !config.rpc_disable { - let api_rpc_user = rpc_api_user(&starknet)?; - let methods_user = rpc_api_build("rpc", api_rpc_user).into(); - - Some(ServerConfig { - name: "JSON-RPC".to_string(), - addr: config.addr_user(), - batch_config: config.batch_config(), - max_connections: config.rpc_max_connections, - max_payload_in_mb: config.rpc_max_request_size, - max_payload_out_mb: config.rpc_max_response_size, - max_subs_per_conn: config.rpc_max_subscriptions_per_connection, - message_buffer_capacity: config.rpc_message_buffer_capacity_per_connection, - methods: methods_user, - metrics: metrics.clone(), - cors: config.cors(), - rpc_version_default: mp_chain_config::RpcVersion::RPC_VERSION_LATEST, - }) - } else { - None - }; - - let server_config_admin = if config.rpc_admin { - let api_rpc_admin = rpc_api_admin(&starknet)?; - let methods_admin = rpc_api_build("admin", api_rpc_admin).into(); - - Some(ServerConfig { - name: "JSON-RPC (Admin)".to_string(), - addr: config.addr_admin(), - batch_config: config.batch_config(), - max_connections: config.rpc_max_connections, - max_payload_in_mb: config.rpc_max_request_size, - max_payload_out_mb: config.rpc_max_response_size, - max_subs_per_conn: config.rpc_max_subscriptions_per_connection, - message_buffer_capacity: config.rpc_message_buffer_capacity_per_connection, - methods: methods_admin, - metrics, - cors: config.cors(), - rpc_version_default: mp_chain_config::RpcVersion::RPC_VERSION_LATEST_ADMIN, - }) - } else { - None - }; - - if let Some(server_config) = &server_config_user { - // rpc enabled - self.server_handle_user = Some(start_server(server_config.clone(), join_set, ctx.clone()).await?); - } - - if let Some(server_config) = &server_config_admin { - // rpc enabled (admin) - let ctx = ctx.child().with_id(MadaraService::RpcAdmin); - ctx.service_add(MadaraService::RpcAdmin); - self.server_handle_admin = Some(start_server(server_config.clone(), join_set, ctx).await?); - } - - Ok(()) + async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { + let config = self.config.clone(); + let backend = Arc::clone(&self.backend); + let add_tx_provider_l2_sync = Arc::clone(&self.add_txs_provider_l2_sync); + let add_tx_provider_mempool = Arc::clone(&self.add_txs_provider_mempool); + let rpc_type = self.rpc_type.clone(); + + let (stop_handle, server_handle) = jsonrpsee::server::stop_channel(); + + self.server_handle = Some(server_handle); + + runner.service_loop(move |ctx| async move { + let add_tx_provider = Arc::new(AddTransactionProviderGroup::new( + add_tx_provider_l2_sync, + add_tx_provider_mempool, + ctx.clone(), + )); + + let starknet = Starknet::new(backend.clone(), add_tx_provider, config.storage_proof_config(), ctx.clone()); + let metrics = RpcMetrics::register()?; + + let server_config = { + let (name, addr, api_rpc, rpc_version_default) = match rpc_type { + RpcType::User => ( + "JSON-RPC".to_string(), + config.addr_user(), + rpc_api_user(&starknet)?, + mp_chain_config::RpcVersion::RPC_VERSION_LATEST, + ), + RpcType::Admin => ( + "JSON-RPC (Admin)".to_string(), + config.addr_admin(), + rpc_api_admin(&starknet)?, + mp_chain_config::RpcVersion::RPC_VERSION_LATEST_ADMIN, + ), + }; + let methods = rpc_api_build("rpc", api_rpc).into(); + + ServerConfig { + name, + addr, + batch_config: config.batch_config(), + max_connections: config.rpc_max_connections, + max_payload_in_mb: config.rpc_max_request_size, + max_payload_out_mb: config.rpc_max_response_size, + max_subs_per_conn: config.rpc_max_subscriptions_per_connection, + message_buffer_capacity: config.rpc_message_buffer_capacity_per_connection, + methods, + metrics, + cors: config.cors(), + rpc_version_default, + } + }; + + start_server(server_config, ctx.clone(), stop_handle).await?; + + anyhow::Ok(()) + }); + + anyhow::Ok(()) } +} - fn id(&self) -> MadaraService { - MadaraService::Rpc +impl ServiceId for RpcService { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + match self.rpc_type { + RpcType::User => MadaraServiceId::RpcUser.svc_id(), + RpcType::Admin => MadaraServiceId::RpcAdmin.svc_id(), + } } } diff --git a/crates/node/src/service/rpc/server.rs b/crates/node/src/service/rpc/server.rs index 07e3a66f0..6dd4acd25 100644 --- a/crates/node/src/service/rpc/server.rs +++ b/crates/node/src/service/rpc/server.rs @@ -7,11 +7,8 @@ use std::time::Duration; use anyhow::Context; use mp_utils::service::ServiceContext; -use tokio::task::JoinSet; use tower::Service; -use mp_utils::wait_or_graceful_shutdown; - use crate::service::rpc::middleware::RpcMiddlewareServiceVersion; use super::metrics::RpcMetrics; @@ -46,11 +43,13 @@ struct PerConnection<RpcMiddleware, HttpMiddleware> { } /// Start RPC server listening on given address. -pub async fn start_server( +/// +/// This future will complete once the server has been shutdown. +pub async fn start_server<'a>( config: ServerConfig, - join_set: &mut JoinSet<anyhow::Result<()>>, - ctx: ServiceContext, -) -> anyhow::Result<jsonrpsee::server::ServerHandle> { + mut ctx: ServiceContext, + stop_handle: jsonrpsee::server::StopHandle, +) -> anyhow::Result<()> { let ServerConfig { name, addr, @@ -91,7 +90,6 @@ pub async fn start_server( .set_http_middleware(http_middleware) .set_id_provider(jsonrpsee::server::RandomStringIdProvider::new(16)); - let (stop_handle, server_handle) = jsonrpsee::server::stop_channel(); let cfg = PerConnection { methods, stop_handle: stop_handle.clone(), @@ -109,6 +107,7 @@ pub async fn start_server( Ok::<_, Infallible>(hyper::service::service_fn(move |req| { let PerConnection { service_builder, metrics, stop_handle, methods } = cfg.clone(); + let ctx1 = ctx1.clone(); let is_websocket = jsonrpsee::server::ws::is_upgrade_request(&req); let transport_label = if is_websocket { "ws" } else { "http" }; @@ -122,10 +121,9 @@ pub async fn start_server( .layer(metrics_layer.clone()); let mut svc = service_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); - let ctx1 = ctx1.clone(); async move { - if !ctx1.is_active() { + if ctx1.is_cancelled() { Ok(hyper::Response::builder() .status(hyper::StatusCode::GONE) .body(hyper::Body::from("GONE"))?) @@ -157,21 +155,18 @@ pub async fn start_server( .with_context(|| format!("Creating hyper server at: {addr}"))? .serve(make_service); - join_set.spawn(async move { - tracing::info!( - "📱 Running {name} server at {} (allowed origins={})", - local_addr.to_string(), - format_cors(cors.as_ref()) - ); - server - .with_graceful_shutdown(async { - wait_or_graceful_shutdown(stop_handle.shutdown(), &ctx).await; - }) - .await - .context("Running rpc server") - }); + tracing::info!( + "📱 Running {name} server at {} (allowed origins={})", + local_addr.to_string(), + format_cors(cors.as_ref()) + ); - Ok(server_handle) + server + .with_graceful_shutdown(async { + ctx.run_until_cancelled(stop_handle.shutdown()).await; + }) + .await + .context("Running rpc server") } // Copied from https://github.com/paritytech/polkadot-sdk/blob/a0aefc6b233ace0a82a8631d67b6854e6aeb014b/substrate/client/rpc-servers/src/utils.rs#L192 diff --git a/crates/primitives/transactions/src/lib.rs b/crates/primitives/transactions/src/lib.rs index d045d3ac9..ab5c27d99 100644 --- a/crates/primitives/transactions/src/lib.rs +++ b/crates/primitives/transactions/src/lib.rs @@ -23,7 +23,6 @@ const SIMULATE_TX_VERSION_OFFSET: Felt = Felt::from_hex_unchecked("0x10000000000 /// Legacy check for deprecated txs /// See `https://docs.starknet.io/documentation/architecture_and_concepts/Blocks/transactions/` for more details. - pub const LEGACY_BLOCK_NUMBER: u64 = 1470; pub const V0_7_BLOCK_NUMBER: u64 = 833; diff --git a/crates/primitives/utils/Cargo.toml b/crates/primitives/utils/Cargo.toml index 6688bb84c..aab8d40c2 100644 --- a/crates/primitives/utils/Cargo.toml +++ b/crates/primitives/utils/Cargo.toml @@ -26,6 +26,8 @@ anyhow.workspace = true async-trait.workspace = true crypto-bigint.workspace = true futures.workspace = true +num-traits.workspace = true +paste.workspace = true rand.workspace = true rayon.workspace = true serde.workspace = true @@ -53,6 +55,7 @@ tracing-subscriber = { workspace = true, features = ["env-filter"] } [dev-dependencies] rstest.workspace = true +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } [features] testing = [] diff --git a/crates/primitives/utils/src/lib.rs b/crates/primitives/utils/src/lib.rs index 5aac81131..d38d5098c 100644 --- a/crates/primitives/utils/src/lib.rs +++ b/crates/primitives/utils/src/lib.rs @@ -7,8 +7,6 @@ pub mod service; use std::time::{Duration, Instant}; -use futures::Future; -use service::ServiceContext; use tokio::sync::oneshot; /// Prefer this compared to [`tokio::spawn_blocking`], as spawn_blocking creates new OS threads and @@ -27,43 +25,6 @@ where rx.await.expect("tokio channel closed") } -async fn graceful_shutdown_inner(ctx: &ServiceContext) { - let sigterm = async { - match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { - Ok(mut signal) => signal.recv().await, - // SIGTERM not supported - Err(_) => core::future::pending().await, - } - }; - - tokio::select! { - _ = tokio::signal::ctrl_c() => {}, - _ = sigterm => {}, - _ = ctx.cancelled() => {}, - }; - - ctx.cancel_local() -} -pub async fn graceful_shutdown(ctx: &ServiceContext) { - graceful_shutdown_inner(ctx).await -} - -/// Should be used with streams/channels `next`/`recv` function. -pub async fn wait_or_graceful_shutdown<T>(future: impl Future<Output = T>, ctx: &ServiceContext) -> Option<T> { - tokio::select! { - _ = graceful_shutdown_inner(ctx) => { None }, - res = future => { Some(res) }, - } -} - -/// Should be used with streams/channels `next`/`recv` function. -pub async fn channel_wait_or_graceful_shutdown<T>( - future: impl Future<Output = Option<T>>, - ctx: &ServiceContext, -) -> Option<T> { - wait_or_graceful_shutdown(future, ctx).await? -} - #[derive(Debug, Default)] pub struct StopHandle(Option<oneshot::Sender<()>>); diff --git a/crates/primitives/utils/src/service.rs b/crates/primitives/utils/src/service.rs index e8d5b1da0..ddba4eb93 100644 --- a/crates/primitives/utils/src/service.rs +++ b/crates/primitives/utils/src/service.rs @@ -1,127 +1,646 @@ -//! Service trait and combinators. +//! Madara Services Architecture +//! +//! Madara follows a [microservice](microservices) architecture to simplify the +//! composability and parallelism of its services. That is to say services can +//! be started in different orders, at different points in the program's +//! execution, stopped and even restarted. The advantage in parallelism arises +//! from the fact that each services runs as its own non-blocking asynchronous +//! task which allows for high throughput. Inter-service communication is done +//! via [tokio::sync] or more often through direct database reads and writes. +//! +//! --- +//! +//! # The [Service] trait +//! +//! This is the backbone of Madara services and serves as a common interface to +//! all. The [Service] trait specifies how a service must start as well as how +//! to _identify_ it. For reasons of atomicity, services are currently +//! identified by a single [std::sync::atomic::AtomicU64]. More about this later. +//! +//! Services are started with [Service::start] using [ServiceRunner::service_loop]. +//! [ServiceRunner::service_loop] is a function which takes in a future: this +//! future represents the main loop of your service, and should run until your +//! service completes or is cancelled. +//! +//! It is part of the contract of the [Service] trait that calls to +//! [ServiceRunner::service_loop] should not complete until the service has +//! _finished_ execution (this should be evident by the name) as this is used +//! to mark a service as complete and therefore ready to restart. Services where +//! [ServiceRunner::service_loop] completes _before_ the service has finished +//! execution will be automatically marked for shutdown as a safety mechanism. +//! This is done as a safeguard to avoid an invalid state where it would be +//! impossible for the node to shutdown. +//! +//! > **Note** +//! > It is assumed that services can and might be restarted. You have the +//! > responsibility to ensure this is possible. This means you should make sure +//! > not to use the like of [std::mem::take] or similar on your service inside +//! > [Service::start]. In general, make sure your service still contains all +//! > the necessary information it needs to restart. This might mean certain +//! > attributes need to be stored as a [std::sync::Arc] and cloned so that the +//! > future in [ServiceRunner::service_loop] can safely take ownership of them. +//! +//! ## An incorrect implementation of the [Service] trait +//! +//! ```rust +//! # use mp_utils::service::Service; +//! # use mp_utils::service::ServiceId; +//! # use mp_utils::service::PowerOfTwo; +//! # use mp_utils::service::ServiceRunner; +//! # use mp_utils::service::MadaraServiceId; +//! +//! pub struct MyService; +//! +//! #[async_trait::async_trait] +//! impl Service for MyService { +//! async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { +//! runner.service_loop(move |ctx| async { +//! tokio::task::spawn(async { +//! tokio::time::sleep(std::time::Duration::MAX).await; +//! }); +//! +//! // This is incorrect, as the future passed to service_loop will +//! // resolve before the task spawned above completes, meaning +//! // Madara will incorrectly mark this service as ready to restart. +//! // In a more complex scenario, this means we might enter an +//! // invalid state! +//! anyhow::Ok(()) +//! }); +//! +//! anyhow::Ok(()) +//! } +//! } +//! +//! impl ServiceId for MyService { +//! fn svc_id(&self) -> PowerOfTwo { +//! MadaraServiceId::Monitor.svc_id() +//! } +//! } +//! ``` +//! +//! ## A correct implementation of the [Service] trait +//! +//! ```rust +//! # use mp_utils::service::Service; +//! # use mp_utils::service::ServiceId; +//! # use mp_utils::service::PowerOfTwo; +//! # use mp_utils::service::ServiceRunner; +//! # use mp_utils::service::MadaraServiceId; +//! +//! pub struct MyService; +//! +//! #[async_trait::async_trait] +//! impl Service for MyService { +//! async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { +//! runner.service_loop(move |mut ctx| async move { +//! ctx.run_until_cancelled(tokio::time::sleep(std::time::Duration::MAX)).await; +//! +//! // This is correct, as the future passed to service_loop will +//! // only resolve once the task above completes, so Madara can +//! // correctly mark this service as ready to restart. +//! anyhow::Ok(()) +//! }); +//! +//! anyhow::Ok(()) +//! } +//! } +//! +//! impl ServiceId for MyService { +//! fn svc_id(&self) -> PowerOfTwo { +//! MadaraServiceId::Monitor.svc_id() +//! } +//! } +//! ``` +//! +//! Or if you really need to spawn a background task: +//! +//! ```rust +//! # use mp_utils::service::Service; +//! # use mp_utils::service::ServiceId; +//! # use mp_utils::service::PowerOfTwo; +//! # use mp_utils::service::ServiceRunner; +//! # use mp_utils::service::MadaraServiceId; +//! +//! pub struct MyService; +//! +//! #[async_trait::async_trait] +//! impl Service for MyService { +//! async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { +//! runner.service_loop(move |mut ctx| async move { +//! let mut ctx1 = ctx.clone(); +//! tokio::task::spawn(async move { +//! ctx1.run_until_cancelled(tokio::time::sleep(std::time::Duration::MAX)).await; +//! }); +//! +//! ctx.cancelled().await; +//! +//! // This is correct, as even though we are spawning a background +//! // task we have implemented a cancellation mechanism with ctx +//! // and are waiting for that cancellation in service_loop. +//! anyhow::Ok(()) +//! }); +//! +//! anyhow::Ok(()) +//! } +//! } +//! +//! impl ServiceId for MyService { +//! fn svc_id(&self) -> PowerOfTwo { +//! MadaraServiceId::Monitor.svc_id() +//! } +//! } +//! ``` +//! +//! This sort of problem generally arises in cases similar to the above, where +//! the service's role is to spawn another background task. This is can happen +//! when the service needs to start a server for example. Either avoid spawning +//! a detached task or use mechanisms such as [ServiceContext::cancelled] to +//! await for the service's completion. +//! +//! Note that service shutdown is designed to be manual. We still implement a +//! [SERVICE_GRACE_PERIOD] which is the maximum duration a service is allowed +//! to take to shutdown, after which it is forcefully cancelled. This should not +//! happen in practice and only serves to avoid cases where someone would forget +//! to implement a cancellation check. More on this in the next section. +//! +//! --- +//! +//! # Cancellation status and inter-process requests +//! +//! Services are passed a [ServiceContext] as part of [ServiceRunner::service_loop] +//! to be used during their execution to check for and request cancellation. +//! Services can also start child services with [ServiceContext::child] to +//! create a hierarchy of services. +//! +//! ## Cancellation checks +//! +//! The main advantage of [ServiceContext] is that it allows you to gracefully +//! handle the shutdown of your services by checking for cancellation at logical +//! points in the execution, such as every iteration of a service's main loop. +//! You can use the following methods to check for cancellation, each with their +//! own caveats. +//! +//! - [ServiceContext::is_cancelled]: synchronous, useful in non-blocking +//! scenarios. +//! - [ServiceContext::cancelled]: a future which resolves upon service +//! cancellation. Useful to wait on a service or alongside [tokio::select]. +//! +//! > **Warning** +//! > It is your responsibility to check for cancellation inside of your +//! > service. If you do not, or your service takes longer than +//! > [SERVICE_GRACE_PERIOD] to shutdown, then your service will be forcefully +//! > cancelled. +//! +//! ## Cancellation requests +//! +//! Any service with access to a [ServiceContext] can request the cancellation +//! of _any other service, at any point during execution_. This can be used for +//! error handling for example, by having a single service shut itself down +//! without affecting other services, or for administrative and testing purposes +//! by having a node operator toggle services on and off from a remote endpoint. +//! +//! You can use the following methods to request for the cancellation of a +//! service: +//! +//! - [ServiceContext::cancel_global]: cancels all services. +//! - [ServiceContext::cancel_local]: cancels this service and all its children. +//! - [ServiceContext::service_remove]: cancel a specific service. +//! +//! ## Start requests +//! +//! You can _request_ for a service to be restarted by calling +//! [ServiceContext::service_add]. This is not guaranteed to work, and will fail +//! if the service is already running or if it has not been registered to +//! [the set of global services](#service-orchestration) at the start of the +//! program. +//! +//! ## Atomic status checks +//! +//! All service updates and checks are performed atomically with the use of +//! [tokio_util::sync::CancellationToken] and [MadaraServiceMask], which is a +//! [std::sync::atomic::AtomicU64] bitmask with strong [std::sync::atomic::Ordering::SeqCst] +//! cross-thread ordering of operations. Services are represented as a unique +//! [PowerOfTwo] which is provided through the [ServiceId] trait. +//! +//! > **Note** +//! > The use of [std::sync::atomic::AtomicU64] limits the number of possible +//! > services to 64. This might be increased in the future if there is a +//! > demonstrable need for more services, but right now this limit seems +//! > high enough. +//! +//! --- +//! +//! # Service orchestration +//! +//! Services are orchestrated by a [ServiceMonitor], which is responsible for +//! registering services, marking them as active or inactive as well as starting +//! and restarting them upon request. [ServiceMonitor] also handles the +//! cancellation of all services upon receiving a `SIGINT` or `SIGTERM`. +//! +//! > **Important** +//! > Services cannot be started or restarted if they have not been registered +//! > with [ServiceMonitor::with]. +//! +//! Services are run to completion until no service remains, at which point the +//! node will automatically shutdown. +//! +//! [microservices]: https://en.wikipedia.org/wiki/Microservices use anyhow::Context; -use std::{fmt::Display, panic, sync::Arc}; +use futures::Future; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Debug, Display}, + panic, + sync::Arc, + time::Duration, +}; use tokio::task::JoinSet; -#[repr(u8)] -#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)] -pub enum MadaraService { +/// Maximum potential number of services that a [ServiceRunner] can run at once +pub const SERVICE_COUNT_MAX: usize = 64; + +/// Maximum duration a service is allowed to take to shutdown, after which it +/// will be forcefully cancelled +pub const SERVICE_GRACE_PERIOD: Duration = Duration::from_secs(10); + +macro_rules! power_of_two { + ( $($pow:literal),* ) => { + paste::paste! { + #[repr(u64)] + #[derive(Clone, Copy, PartialEq, Eq, Default, Debug)] + pub enum PowerOfTwo { + #[default] + ZERO = 0, + $( + [<P $pow>] = 1u64 << $pow, + )* + } + + impl PowerOfTwo { + /// Converts a [PowerOfTwo] into a unique index which can be + /// used in an arrray + pub fn index(&self) -> usize { + match self { + Self::ZERO => 0, + $( + Self::[<P $pow>] => $pow, + )* + } + } + } + + impl ServiceId for PowerOfTwo { + fn svc_id(&self) -> PowerOfTwo { + *self + } + } + + impl Display for PowerOfTwo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", *self as u64) + } + } + + impl TryFrom<u8> for PowerOfTwo { + type Error = anyhow::Error; + + fn try_from(pow: u8) -> anyhow::Result<Self> { + TryFrom::<u64>::try_from(pow as u64) + } + } + + impl TryFrom<u16> for PowerOfTwo { + type Error = anyhow::Error; + + fn try_from(pow: u16) -> anyhow::Result<Self> { + TryFrom::<u64>::try_from(pow as u64) + } + } + + impl TryFrom<u32> for PowerOfTwo { + type Error = anyhow::Error; + + fn try_from(pow: u32) -> anyhow::Result<Self> { + TryFrom::<u64>::try_from(pow as u64) + } + } + + impl TryFrom<u64> for PowerOfTwo + { + type Error = anyhow::Error; + + fn try_from(pow: u64) -> anyhow::Result<Self> { + $( + const [<P $pow>]: u64 = 1 << $pow; + )* + + let pow: u64 = pow.into(); + match pow { + 0 => anyhow::Ok(Self::ZERO), + $( + [<P $pow>] => anyhow::Ok(Self::[<P $pow>]), + )* + _ => anyhow::bail!("Not a power of two: {pow}"), + } + } + } + } + }; +} + +power_of_two!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63 +); + +/// The core [Service]s available in Madara. +/// +/// Note that [PowerOfTwo::ZERO] represents [MadaraServiceId::Monitor] as +/// [ServiceMonitor] is always running and therefore is the genesis state of all +/// other services. +#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MadaraServiceId { #[default] - None = 0, - Database = 1, - L1Sync = 2, - L2Sync = 4, - BlockProduction = 8, - Rpc = 16, - RpcAdmin = 32, - Gateway = 64, - Telemetry = 128, -} - -impl Display for MadaraService { + #[serde(skip)] + Monitor, + #[serde(skip)] + Database, + L1Sync, + L2Sync, + BlockProduction, + #[serde(rename = "rpc")] + RpcUser, + #[serde(skip)] + RpcAdmin, + Gateway, + Telemetry, +} + +impl ServiceId for MadaraServiceId { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + match self { + MadaraServiceId::Monitor => PowerOfTwo::ZERO, + MadaraServiceId::Database => PowerOfTwo::P0, + MadaraServiceId::L1Sync => PowerOfTwo::P1, + MadaraServiceId::L2Sync => PowerOfTwo::P2, + MadaraServiceId::BlockProduction => PowerOfTwo::P3, + MadaraServiceId::RpcUser => PowerOfTwo::P4, + MadaraServiceId::RpcAdmin => PowerOfTwo::P5, + MadaraServiceId::Gateway => PowerOfTwo::P6, + MadaraServiceId::Telemetry => PowerOfTwo::P7, + } + } +} + +impl Display for MadaraServiceId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}", match self { - MadaraService::None => "none", - MadaraService::Database => "database", - MadaraService::L1Sync => "l1 sync", - MadaraService::L2Sync => "l2 sync", - MadaraService::BlockProduction => "block production", - MadaraService::Rpc => "rpc", - MadaraService::RpcAdmin => "rpc admin", - MadaraService::Gateway => "gateway", - MadaraService::Telemetry => "telemetry", + Self::Monitor => "monitor", + Self::Database => "database", + Self::L1Sync => "l1 sync", + Self::L2Sync => "l2 sync", + Self::BlockProduction => "block production", + Self::RpcUser => "rpc user", + Self::RpcAdmin => "rpc admin", + Self::Gateway => "gateway", + Self::Telemetry => "telemetry", } ) } } +impl std::ops::BitOr for MadaraServiceId { + type Output = u64; + + fn bitor(self, rhs: Self) -> Self::Output { + self.svc_id() as u64 | rhs.svc_id() as u64 + } +} + +impl std::ops::BitAnd for MadaraServiceId { + type Output = u64; + + fn bitand(self, rhs: Self) -> Self::Output { + self.svc_id() as u64 & rhs.svc_id() as u64 + } +} + +impl From<PowerOfTwo> for MadaraServiceId { + fn from(value: PowerOfTwo) -> Self { + match value { + PowerOfTwo::ZERO => Self::Monitor, + PowerOfTwo::P0 => Self::Database, + PowerOfTwo::P1 => Self::L1Sync, + PowerOfTwo::P2 => Self::L2Sync, + PowerOfTwo::P3 => Self::BlockProduction, + PowerOfTwo::P4 => Self::RpcUser, + PowerOfTwo::P5 => Self::RpcAdmin, + PowerOfTwo::P6 => Self::Gateway, + _ => Self::Telemetry, + } + } +} + +// A boolean status enum, for clarity's sake +#[derive(PartialEq, Eq, Clone, Copy, Default, Serialize, Deserialize)] +pub enum MadaraServiceStatus { + On, + #[default] + Off, +} + +impl Display for MadaraServiceStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::On => "on", + Self::Off => "off", + } + ) + } +} + +impl std::ops::BitOr for MadaraServiceStatus { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + if self.is_on() || rhs.is_on() { + MadaraServiceStatus::On + } else { + MadaraServiceStatus::Off + } + } +} + +impl std::ops::BitOr for &MadaraServiceStatus { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + if self.is_on() || rhs.is_on() { + &MadaraServiceStatus::On + } else { + &MadaraServiceStatus::Off + } + } +} + +impl std::ops::BitAnd for MadaraServiceStatus { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + if self.is_on() && rhs.is_on() { + MadaraServiceStatus::On + } else { + MadaraServiceStatus::Off + } + } +} + +impl std::ops::BitAnd for &MadaraServiceStatus { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + if self.is_on() && rhs.is_on() { + &MadaraServiceStatus::On + } else { + &MadaraServiceStatus::Off + } + } +} + +impl std::ops::BitOrAssign for MadaraServiceStatus { + fn bitor_assign(&mut self, rhs: Self) { + *self = if self.is_on() || rhs.is_on() { MadaraServiceStatus::On } else { MadaraServiceStatus::Off } + } +} + +impl std::ops::BitAndAssign for MadaraServiceStatus { + fn bitand_assign(&mut self, rhs: Self) { + *self = if self.is_on() && rhs.is_on() { MadaraServiceStatus::On } else { MadaraServiceStatus::Off } + } +} + +impl From<bool> for MadaraServiceStatus { + fn from(value: bool) -> Self { + match value { + true => Self::On, + false => Self::Off, + } + } +} + +impl MadaraServiceStatus { + #[inline(always)] + pub fn is_on(&self) -> bool { + self == &MadaraServiceStatus::On + } + + #[inline(always)] + pub fn is_off(&self) -> bool { + self == &MadaraServiceStatus::Off + } +} + +/// An atomic bitmask of each [MadaraServiceId]'s status with strong +/// [std::sync::atomic::Ordering::SeqCst] cross-thread ordering of operations. #[repr(transparent)] #[derive(Default)] -pub struct MadaraServiceMask(std::sync::atomic::AtomicU8); +pub struct MadaraServiceMask(std::sync::atomic::AtomicU64); impl MadaraServiceMask { #[cfg(feature = "testing")] pub fn new_for_testing() -> Self { - Self(std::sync::atomic::AtomicU8::new(u8::MAX)) + Self(std::sync::atomic::AtomicU64::new(u64::MAX)) } #[inline(always)] - pub fn is_active(&self, cap: u8) -> bool { - self.0.load(std::sync::atomic::Ordering::SeqCst) & cap > 0 + pub fn status(&self, svc: impl ServiceId) -> MadaraServiceStatus { + (self.value() & svc.svc_id() as u64 > 0).into() } #[inline(always)] - pub fn activate(&self, cap: MadaraService) -> bool { - let prev = self.0.fetch_or(cap as u8, std::sync::atomic::Ordering::SeqCst); - prev & cap as u8 > 0 + pub fn is_active_some(&self) -> bool { + self.value() > 0 } #[inline(always)] - pub fn deactivate(&self, cap: MadaraService) -> bool { - let cap = cap as u8; - let prev = self.0.fetch_and(!cap, std::sync::atomic::Ordering::SeqCst); - prev & cap > 0 + pub fn activate(&self, svc: impl ServiceId) -> MadaraServiceStatus { + let prev = self.0.fetch_or(svc.svc_id() as u64, std::sync::atomic::Ordering::SeqCst); + (prev & svc.svc_id() as u64 > 0).into() } -} -#[repr(u8)] -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)] -pub enum MadaraState { - #[default] - Starting, - Warp, - Running, - Shutdown, -} + #[inline(always)] + pub fn deactivate(&self, svc: impl ServiceId) -> MadaraServiceStatus { + let svc = svc.svc_id() as u64; + let prev = self.0.fetch_and(!svc, std::sync::atomic::Ordering::SeqCst); + (prev & svc > 0).into() + } -impl From<u8> for MadaraState { - fn from(value: u8) -> Self { - match value { - 0 => Self::Starting, - 1 => Self::Warp, - 2 => Self::Running, - _ => Self::Shutdown, + fn active_set(&self) -> Vec<MadaraServiceId> { + let mut i = MadaraServiceId::Telemetry.svc_id() as u64; + let state = self.value(); + let mut set = Vec::with_capacity(SERVICE_COUNT_MAX); + + while i > 0 { + let mask = state & i; + + if mask > 0 { + let pow = PowerOfTwo::try_from(mask).expect("mask is a power of 2"); + set.push(MadaraServiceId::from(pow)); + } + + i >>= 1; } + + set + } + + fn value(&self) -> u64 { + self.0.load(std::sync::atomic::Ordering::SeqCst) } } -/// Atomic state and cancellation context associated to a Service. +/// Atomic state and cancellation context associated to a [Service]. /// /// # Scope /// -/// You can create a hierarchy of services by calling `ServiceContext::branch_local`. +/// You can create a hierarchy of services by calling [ServiceContext::child]. /// Services are said to be in the same _local scope_ if they inherit the same -/// `token_local` cancellation token. You can think of services being local -/// if they can cancel each other without affecting the rest of the app (this -/// is not exact but it serves as a good mental model). +/// `token_local` [tokio_util::sync::CancellationToken]. You can think of +/// services being local if they can cancel each other without affecting the +/// rest of the app. /// -/// All services which descend from the same context are also said to be in the -/// same _global scope_, that is to say any service in this scope can cancel -/// _all_ other services in the same scope (including child services) at any -/// time. This is true of services in the same [ServiceGroup] for example. +/// All services which are derived from the same [ServiceContext] are said to +/// be in the same _global scope_, that is to say any service in this scope can +/// cancel _all_ other services in the same scope (including child services) at +/// any time. This is true of services in the same [ServiceMonitor] for example. /// -/// # Services +/// # Service hierarchy /// /// - A services is said to be a _child service_ if it uses a context created -/// with `ServiceContext::branch_local` +/// with [ServiceContext::child] /// /// - A service is said to be a _parent service_ if it uses a context which was /// used to create child services. /// /// > A parent services can always cancel all of its child services, but a child /// > service cannot cancel its parent service. -#[cfg_attr(not(feature = "testing"), derive(Default))] pub struct ServiceContext { token_global: tokio_util::sync::CancellationToken, token_local: Option<tokio_util::sync::CancellationToken>, services: Arc<MadaraServiceMask>, - services_notify: Arc<tokio::sync::Notify>, - state: Arc<std::sync::atomic::AtomicU8>, - id: MadaraService, + service_update_sender: Arc<tokio::sync::broadcast::Sender<ServiceTransport>>, + service_update_receiver: Option<tokio::sync::broadcast::Receiver<ServiceTransport>>, + id: PowerOfTwo, } impl Clone for ServiceContext { @@ -130,142 +649,271 @@ impl Clone for ServiceContext { token_global: self.token_global.clone(), token_local: self.token_local.clone(), services: Arc::clone(&self.services), - services_notify: Arc::clone(&self.services_notify), - state: Arc::clone(&self.state), + service_update_sender: Arc::clone(&self.service_update_sender), + service_update_receiver: None, id: self.id, } } } -impl ServiceContext { - pub fn new() -> Self { +impl Default for ServiceContext { + fn default() -> Self { Self { token_global: tokio_util::sync::CancellationToken::new(), token_local: None, services: Arc::new(MadaraServiceMask::default()), - services_notify: Arc::new(tokio::sync::Notify::new()), - state: Arc::new(std::sync::atomic::AtomicU8::new(MadaraState::default() as u8)), - id: MadaraService::default(), + service_update_sender: Arc::new(tokio::sync::broadcast::channel(SERVICE_COUNT_MAX).0), + service_update_receiver: None, + id: MadaraServiceId::Monitor.svc_id(), } } +} + +impl ServiceContext { + /// Creates a new [Default] [ServiceContext] + pub fn new() -> Self { + Self::default() + } #[cfg(feature = "testing")] pub fn new_for_testing() -> Self { - Self { - token_global: tokio_util::sync::CancellationToken::new(), - token_local: None, - services: Arc::new(MadaraServiceMask::new_for_testing()), - services_notify: Arc::new(tokio::sync::Notify::new()), - state: Arc::new(std::sync::atomic::AtomicU8::new(MadaraState::default() as u8)), - id: MadaraService::default(), - } + Self { services: Arc::new(MadaraServiceMask::new_for_testing()), ..Default::default() } + } + + /// Creates a new [Default] [ServiceContext] with the state of its services + /// set to the specified value. + pub fn new_with_services(services: Arc<MadaraServiceMask>) -> Self { + Self { services, ..Default::default() } } /// Stops all services under the same global context scope. pub fn cancel_global(&self) { + tracing::info!("🔌 Gracefully shutting down node"); + self.token_global.cancel(); } /// Stops all services under the same local context scope. /// - /// A local context is created by calling `branch_local` and allows you to - /// reduce the scope of cancellation only to those services which will use - /// the new context. + /// A local context is created by calling [ServiceContext::child] and allows + /// you to reduce the scope of cancellation only to those services which + /// will use the new context. pub fn cancel_local(&self) { self.token_local.as_ref().unwrap_or(&self.token_global).cancel(); } /// A future which completes when the service associated to this - /// [ServiceContext] is canceled. + /// [ServiceContext] is cancelled. /// - /// This happens after calling [ServiceContext::cancel_local] or - /// [ServiceContext::cancel_global]. + /// This allows for more manual implementation of cancellation logic than + /// [ServiceContext::run_until_cancelled], and should only be used in cases + /// where using `run_until_cancelled` is not possible or would be less + /// clear. /// - /// Use this to race against other futures in a [tokio::select] for example. + /// A service is cancelled after calling [ServiceContext::cancel_local], + /// [ServiceContext::cancel_global] or if it is marked for removal with + /// [ServiceContext::service_remove]. + /// + /// Use this to race against other futures in a [tokio::select] or keep a + /// coroutine alive for as long as the service itself. #[inline(always)] - pub async fn cancelled(&self) { - if self.state() != MadaraState::Shutdown { - match &self.token_local { - Some(token_local) => tokio::select! { - _ = self.token_global.cancelled() => {}, - _ = token_local.cancelled() => {} - }, - None => tokio::select! { - _ = self.token_global.cancelled() => {}, - }, + pub async fn cancelled(&mut self) { + if self.service_update_receiver.is_none() { + self.service_update_receiver = Some(self.service_update_sender.subscribe()); + } + + let mut rx = self.service_update_receiver.take().expect("Receiver was set above"); + let token_global = &self.token_global; + let token_local = self.token_local.as_ref().unwrap_or(&self.token_global); + + loop { + // We keep checking for service status updates until a token has + // been cancelled or this service was deactivated + let res = tokio::select! { + svc = rx.recv() => svc.ok(), + _ = token_global.cancelled() => break, + _ = token_local.cancelled() => break + }; + + if let Some(ServiceTransport { svc_id, status }) = res { + if svc_id == self.id && status == MadaraServiceStatus::Off { + return; + } } } } - /// Check if the service associated to this [ServiceContext] was canceled. + /// Checks if the service associated to this [ServiceContext] was cancelled. + /// + /// This happens after calling [ServiceContext::cancel_local], + /// [ServiceContext::cancel_global] or [ServiceContext::service_remove]. /// - /// This happens after calling [ServiceContext::cancel_local] or - /// [ServiceContext::cancel_global]. + /// # Limitations + /// + /// This function should _not_ be used when waiting on potentially + /// blocking futures which can be cancelled without entering an invalid + /// state. The latter is important, so let's break this down. + /// + /// - _blocking future_: this is blocking at a [Service] level, not at the + /// node level. A blocking task in this sense in a task which prevents a + /// service from making progress in its execution, but not necessarily the + /// rest of the node. A prime example of this is when you are waiting on + /// a channel, and updates to that channel are sparse, or even unique. + /// + /// - _entering an invalid state_: the entire point of [ServiceContext] is + /// to allow services to gracefully shutdown. We do not want to be, for + /// example, racing each service against a global cancellation future, as + /// not every service might be cancellation safe (we still do this + /// somewhat with [SERVICE_GRACE_PERIOD] but this is a last resort and + /// should not execute in normal circumstances). Put differently, we do + /// not want to stop in the middle of a critical computation before it has + /// been saved to disk. + /// + /// Putting this together, [ServiceContext::is_cancelled] is only suitable + /// for checking cancellation alongside tasks which will not block the + /// running service, or in very specific circumstances where waiting on a + /// blocking future has higher precedence than shutting down the node. + /// + /// Examples of when to use [ServiceContext::is_cancelled]: + /// + /// - All your computation does is sleep or tick away a short period of + /// time. + /// - You are checking for cancellation inside of synchronous code. + /// + /// If this does not describe your usage, and you are waiting on a blocking + /// future, which is cancel-safe and which does not risk putting the node + /// in an invalid state if cancelled, then you should be using + /// [ServiceContext::cancelled] instead. #[inline(always)] pub fn is_cancelled(&self) -> bool { self.token_global.is_cancelled() || self.token_local.as_ref().map(|t| t.is_cancelled()).unwrap_or(false) - || !self.services.is_active(self.id as u8) - || self.state() == MadaraState::Shutdown + || self.services.status(self.id) == MadaraServiceStatus::Off + } + + /// Runs a [Future] until the [Service] associated to this [ServiceContext] + /// is cancelled. + /// + /// This happens after calling [ServiceContext::cancel_local], + /// [ServiceContext::cancel_global] or [ServiceContext::service_remove]. + /// + /// # Cancellation safety + /// + /// It is important that the future you pass to this function is _cancel- + /// safe_ as it will be forcefully shutdown if ever the service is cancelled. + /// This means your future might be interrupted at _any_ point in its + /// execution. + /// + /// Futures can be considered as cancel-safe in the context of Madara if + /// their computation can be interrupted at any point without causing any + /// side-effects to the running node. + /// + /// # Returns + /// + /// The return value of the future wrapped in [Some], or [None] if the + /// service was cancelled. + pub async fn run_until_cancelled<T, F>(&mut self, f: F) -> Option<T> + where + T: Sized + Send + Sync, + F: Future<Output = T>, + { + tokio::select! { + res = f => Some(res), + _ = self.cancelled() => None + } } - /// The id of service associated to this [ServiceContext] - pub fn id(&self) -> MadaraService { + /// The id of the [Service] associated to this [ServiceContext] + pub fn id(&self) -> PowerOfTwo { self.id } - /// Copies the context, maintaining its scope but with a new id. - pub fn with_id(mut self, id: MadaraService) -> Self { - self.id = id; + /// Sets the id of this [ServiceContext] + pub fn with_id(mut self, id: impl ServiceId) -> Self { + self.id = id.svc_id(); self } - /// Copies the context into a new local scope. + /// Creates a new [ServiceContext] as a child of the current context. /// - /// Any service which uses this new context will be able to cancel the + /// Any [Service] which uses this new context will be able to cancel the /// services in the same local scope as itself, and any further child /// services, without affecting the rest of the global scope. pub fn child(&self) -> Self { let token_local = self.token_local.as_ref().unwrap_or(&self.token_global).child_token(); - Self { - token_global: self.token_global.clone(), - token_local: Some(token_local), - services: Arc::clone(&self.services), - services_notify: Arc::clone(&self.services_notify), - state: Arc::clone(&self.state), - id: self.id, - } + Self { token_local: Some(token_local), ..Clone::clone(self) } } - /// Atomically checks if a set of services are running. - /// - /// You can combine multiple [MadaraService] into a single bitmask to - /// check the state of multiple services at once. + /// Atomically checks if a [Service] is running. #[inline(always)] - pub fn service_check(&self, cap: u8) -> bool { - self.services.is_active(cap) + pub fn service_status(&self, svc: impl ServiceId) -> MadaraServiceStatus { + self.services.status(svc) } - /// Atomically marks a service as active + /// Atomically marks a [Service] as active. /// /// This will immediately be visible to all services in the same global /// scope. This is true across threads. + /// + /// You can use [ServiceContext::service_subscribe] to subscribe to changes + /// in the status of any service. #[inline(always)] - pub fn service_add(&self, cap: MadaraService) -> bool { - let res = self.services.activate(cap); - self.services_notify.notify_waiters(); + pub fn service_add(&self, id: impl ServiceId) -> MadaraServiceStatus { + let svc_id = id.svc_id(); + let res = self.services.activate(id); + + // TODO: make an internal server error out of this + let _ = self.service_update_sender.send(ServiceTransport { svc_id, status: MadaraServiceStatus::On }); res } - /// Atomically marks a service as inactive + /// Atomically marks a [Service] as inactive. /// /// This will immediately be visible to all services in the same global /// scope. This is true across threads. + /// + /// You can use [ServiceContext::service_subscribe] to subscribe to changes + /// in the status of any service. #[inline(always)] - pub fn service_remove(&self, cap: MadaraService) -> bool { - self.services.deactivate(cap) + pub fn service_remove(&self, id: impl ServiceId) -> MadaraServiceStatus { + let svc_id = id.svc_id(); + let res = self.services.deactivate(id); + let _ = self.service_update_sender.send(ServiceTransport { svc_id, status: MadaraServiceStatus::Off }); + + res + } + + /// Opens up a new subscription which will complete once the status of a + /// [Service] has been updated. + /// + /// This subscription is stored on first call to this method and can be + /// accessed through the same instance of [ServiceContext]. + /// + /// # Returns + /// + /// Identifying information about the service which was updated as well + /// as its new [MadaraServiceStatus]. + pub async fn service_subscribe(&mut self) -> Option<ServiceTransport> { + if self.service_update_receiver.is_none() { + self.service_update_receiver = Some(self.service_update_sender.subscribe()); + } + + let mut rx = self.service_update_receiver.take().expect("Receiver was set above"); + let token_global = &self.token_global; + let token_local = self.token_local.as_ref().unwrap_or(&self.token_global); + + let res = tokio::select! { + svc = rx.recv() => svc.ok(), + _ = token_global.cancelled() => None, + _ = token_local.cancelled() => None + }; + + // ownership hack: `rx` cannot depend on a mutable borrow to `self` as we + // also depend on immutable borrows for `token_local` and `token_global` + self.service_update_receiver = Some(rx); + res } /// Atomically checks if the service associated to this [ServiceContext] is @@ -274,110 +922,414 @@ impl ServiceContext { /// This can be updated across threads by calling [ServiceContext::service_remove] /// or [ServiceContext::service_add] #[inline(always)] - pub fn is_active(&self) -> bool { - self.services.is_active(self.id as u8) - } - - /// Atomically checks the state of the node - #[inline(always)] - pub fn state(&self) -> MadaraState { - self.state.load(std::sync::atomic::Ordering::SeqCst).into() + pub fn status(&self) -> MadaraServiceStatus { + self.services.status(self.id) } +} - /// Atomically sets the state of the node - /// - /// This will immediately be visible to all services in the same global - /// scope. This is true across threads. - pub fn state_advance(&mut self) -> MadaraState { - let state = self.state.load(std::sync::atomic::Ordering::SeqCst).saturating_add(1); - self.state.store(state, std::sync::atomic::Ordering::SeqCst); - state.into() - } +/// Provides info about a [Service]'s status. +/// +/// Used as part of [ServiceContext::service_subscribe]. +#[derive(Clone, Copy)] +pub struct ServiceTransport { + pub svc_id: PowerOfTwo, + pub status: MadaraServiceStatus, } -/// The app is divided into services, with each service having a different responsability within the app. -/// Depending on the startup configuration, some services are enabled and some are disabled. +/// A microservice in the Madara node. +/// +/// The app is divided into services, with each service handling different +/// responsibilities within the app. Depending on the startup configuration, +/// some services are enabled and some are disabled. +/// +/// Services should be started with [ServiceRunner::service_loop]. +/// +/// # Writing your own service +/// +/// Writing a service involves two main steps: +/// +/// 1. Implementing the [Service] trait +/// 2. Implementing the [ServiceId] trait +/// +/// It is also recommended you create your own enum for storing service ids +/// which itself implements [ServiceId]. This helps keep your code organized as +/// [PowerOfTwo::P17] does not have much meaning in of itself. +/// +/// ```rust +/// # use mp_utils::service::Service; +/// # use mp_utils::service::ServiceId; +/// # use mp_utils::service::PowerOfTwo; +/// # use mp_utils::service::ServiceRunner; +/// # use mp_utils::service::ServiceMonitor; +/// +/// // This enum only exist to make it easier for us to remember which +/// // PowerOfTwo represents our services. +/// pub enum MyServiceId { +/// MyServiceA, +/// MyServiceB +/// } +/// +/// impl ServiceId for MyServiceId { +/// #[inline(always)] +/// fn svc_id(&self) -> PowerOfTwo { +/// match self { +/// // PowerOfTwo::P0 up until PowerOfTwo::P7 are already in use by +/// // MadaraServiceId, you should not use them! +/// Self::MyServiceA => PowerOfTwo::P8, +/// Self::MyServiceB => PowerOfTwo::P9, +/// } +/// } +/// } +/// +/// // Similarly, this enum is more explicit for our usecase than Option<T> +/// #[derive(Clone, Debug)] +/// pub enum Channel<T: Sized + Send + Sync> { +/// Open(T), +/// Closed +/// } +/// +/// // An example service, sends over 4 integers to `ServiceB` and the exits +/// struct MyServiceA(tokio::sync::broadcast::Sender<Channel<usize>>); +/// +/// #[async_trait::async_trait] +/// impl Service for MyServiceA { +/// async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { +/// let mut sx = self.0.clone(); +/// +/// runner.service_loop(move |mut ctx| async move { +/// for i in 0..4 { +/// sx.send(Channel::Open(i))?; +/// +/// const SLEEP: std::time::Duration = std::time::Duration::from_secs(1); +/// ctx.run_until_cancelled(tokio::time::sleep(SLEEP)).await; +/// } +/// +/// // An important subtlety: we are using a broadcast channel to +/// // keep the connection alive between A and B even between +/// // restarts. To do this, we always keep a broadcast sender and +/// // receiver alive in A and B respectively, which we clone +/// // whenever either service starts. This means the channel won't +/// // close when the sender in A's service_loop is dropped! We need +/// // to explicitly notify B that it has received all the +/// // information A has to send to it, which is why we use the +/// // `Channel` enum. +/// sx.send(Channel::Closed); +/// +/// anyhow::Ok(()) +/// }); +/// +/// anyhow::Ok(()) +/// } +/// } +/// +/// impl ServiceId for MyServiceA { +/// fn svc_id(&self) -> PowerOfTwo { +/// MyServiceId::MyServiceA.svc_id() +/// } +/// } +/// +/// // An example service, listens for messages from `ServiceA` and the exits +/// struct MyServiceB(tokio::sync::broadcast::Receiver<Channel<usize>>); /// -/// This trait enables launching nested services and groups. +/// #[async_trait::async_trait] +/// impl Service for MyServiceB { +/// async fn start<'a>(&mut self, runner: ServiceRunner<'a>) -> anyhow::Result<()> { +/// let mut rx = self.0.resubscribe(); +/// +/// runner.service_loop(move |mut ctx| async move { +/// loop { +/// let i = tokio::select! { +/// res = rx.recv() => { +/// // As mentioned above, `res` will never receive an +/// // `Err(RecvError::Closed)` since we always keep a +/// // sender alive in A for restarts, so we manually +/// // check if the channel was closed. +/// match res? { +/// Channel::Open(i) => i, +/// Channel::Closed => break, +/// } +/// }, +/// // This is a case where using `ctx.run_until_cancelled` +/// // would probably be harder to read. +/// _ = ctx.cancelled() => break, +/// }; +/// +/// println!("MyServiceB received {i}"); +/// } +/// +/// anyhow::Ok(()) +/// }); +/// +/// anyhow::Ok(()) +/// } +/// } +/// +/// impl ServiceId for MyServiceB { +/// fn svc_id(&self) -> PowerOfTwo { +/// MyServiceId::MyServiceB.svc_id() +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() -> anyhow::Result<()> { +/// let (sx, rx) = tokio::sync::broadcast::channel(16); +/// +/// let service_a = MyServiceA(sx); +/// let service_b = MyServiceB(rx); +/// +/// let monitor = ServiceMonitor::default() +/// .with(service_a)? +/// .with(service_b)?; +/// +/// // We can use `MyServiceId` directly here. Most service methods only +/// // require an `impl ServiceId`, so this kind of pattern is very much +/// // recommended. +/// monitor.activate(MyServiceId::MyServiceA); +/// monitor.activate(MyServiceId::MyServiceB); +/// +/// monitor.start().await?; +/// +/// anyhow::Ok(()) +/// } +/// ``` #[async_trait::async_trait] -pub trait Service: 'static + Send + Sync { +pub trait Service: 'static + Send + Sync + ServiceId { /// Default impl does not start any task. - async fn start(&mut self, _join_set: &mut JoinSet<anyhow::Result<()>>, _ctx: ServiceContext) -> anyhow::Result<()> { + async fn start<'a>(&mut self, _runner: ServiceRunner<'a>) -> anyhow::Result<()> { Ok(()) } +} + +/// Allows a [Service] to identify itself +/// +/// Services are identified using a unique [PowerOfTwo] +pub trait ServiceId { + fn svc_id(&self) -> PowerOfTwo; +} + +#[async_trait::async_trait] +impl Service for Box<dyn Service> { + async fn start<'a>(&mut self, _runner: ServiceRunner<'a>) -> anyhow::Result<()> { + self.as_mut().start(_runner).await + } +} + +impl ServiceId for Box<dyn Service> { + #[inline(always)] + fn svc_id(&self) -> PowerOfTwo { + self.as_ref().svc_id() + } +} + +/// Wrapper around a [tokio::task::JoinSet] and a [ServiceContext]. +/// +/// Used to enforce certain shutdown behavior onto [Service]s which are started +/// with [ServiceRunner::service_loop] +pub struct ServiceRunner<'a> { + ctx: ServiceContext, + join_set: &'a mut JoinSet<anyhow::Result<PowerOfTwo>>, +} - async fn start_and_drive_to_end(mut self) -> anyhow::Result<()> +impl<'a> ServiceRunner<'a> { + fn new(ctx: ServiceContext, join_set: &'a mut JoinSet<anyhow::Result<PowerOfTwo>>) -> Self { + Self { ctx, join_set } + } + + /// The main loop of a [Service]. + /// + /// The future passed to this function should complete _only once the + /// service completes or is cancelled_. Services that complete early will + /// automatically be cancelled. + /// + /// > **Caution** + /// > As a safety mechanism, services have up to [SERVICE_GRACE_PERIOD] + /// > to gracefully shutdown before they are forcefully cancelled. This + /// > should not execute in a normal context and only serves to prevent + /// > infinite loops on shutdown request if services have not been + /// > implemented correctly. + #[tracing::instrument(skip(self, runner), fields(module = "Service"))] + pub fn service_loop<F, E>(self, runner: impl FnOnce(ServiceContext) -> F + Send + 'static) where - Self: Sized, + F: Future<Output = Result<(), E>> + Send + 'static, + E: Into<anyhow::Error> + Send, { - let mut join_set = JoinSet::new(); - self.start(&mut join_set, ServiceContext::new()).await.context("Starting service")?; - drive_joinset(join_set).await + let Self { ctx, join_set } = self; + join_set.spawn(async move { + let id = ctx.id(); + if id != MadaraServiceId::Monitor.svc_id() { + tracing::debug!("Starting service with id: {id}"); + } + + // If a service is implemented correctly, `stopper` should never + // cancel first. This is a safety measure in case someone forgets to + // implement a cancellation check along some branch of the service's + // execution, or if they don't read the docs :D + let ctx1 = ctx.clone(); + tokio::select! { + res = runner(ctx) => res.map_err(Into::into)?, + _ = Self::stopper(ctx1, &id) => {}, + } + + if id != MadaraServiceId::Monitor.svc_id() { + tracing::debug!("Shutting down service with id: {id}"); + } + + Ok(id) + }); } - fn id(&self) -> MadaraService; + async fn stopper(mut ctx: ServiceContext, id: &PowerOfTwo) { + ctx.cancelled().await; + tokio::time::sleep(SERVICE_GRACE_PERIOD).await; + + tracing::warn!("⚠️ Forcefully shutting down service with id: {id}"); + } } -pub struct ServiceGroup { - services: Vec<Box<dyn Service>>, - join_set: Option<JoinSet<anyhow::Result<()>>>, +pub struct ServiceMonitor { + services: [Option<Box<dyn Service>>; SERVICE_COUNT_MAX], + join_set: JoinSet<anyhow::Result<PowerOfTwo>>, + status_request: Arc<MadaraServiceMask>, + status_actual: Arc<MadaraServiceMask>, } -impl Default for ServiceGroup { +impl Default for ServiceMonitor { fn default() -> Self { - Self { services: vec![], join_set: Some(Default::default()) } + Self { + services: [const { None }; SERVICE_COUNT_MAX], + join_set: JoinSet::new(), + status_request: Arc::default(), + status_actual: Arc::default(), + } } } -impl ServiceGroup { - pub fn new(services: Vec<Box<dyn Service>>) -> Self { - Self { services, join_set: Some(Default::default()) } - } +/// Orchestrates the execution of various [Service]s. +/// +/// A [ServiceMonitor] is responsible for registering services, starting and +/// stopping them as well as handling `SIGINT` and `SIGTERM`. Services are run +/// to completion until no service remains, at which point the node will +/// automatically shutdown. +/// +/// All services are inactive by default. Only the services which are marked as +/// _explicitly active_ with [ServiceMonitor::activate] will be automatically +/// started when calling [ServiceMonitor::start]. If no service was activated, +/// the node will shutdown. +/// +/// Note that services which are not added with [ServiceMonitor::with] cannot +/// be started or restarted. +impl ServiceMonitor { + /// Registers a [Service] to the [ServiceMonitor]. This service is + /// _inactive_ by default and can be started at a later time. + pub fn with(mut self, svc: impl Service) -> anyhow::Result<Self> { + let idx = svc.svc_id().index(); + self.services[idx] = match self.services[idx] { + Some(_) => anyhow::bail!("Services has already been added"), + None => Some(Box::new(svc)), + }; - /// Add a new service to the service group. - pub fn push(&mut self, value: impl Service) { - if self.join_set.is_none() { - panic!("Cannot add services to a group that has been started.") - } - self.services.push(Box::new(value)); + anyhow::Ok(self) } - pub fn with(mut self, value: impl Service) -> Self { - self.push(value); - self + /// Marks a [Service] as active, meaning it will be started automatically + /// when calling [ServiceMonitor::start]. + pub fn activate(&self, id: impl ServiceId) { + self.status_request.activate(id); } -} -#[async_trait::async_trait] -impl Service for ServiceGroup { - async fn start(&mut self, join_set: &mut JoinSet<anyhow::Result<()>>, ctx: ServiceContext) -> anyhow::Result<()> { - // drive the join set as a nested task - let mut own_join_set = self.join_set.take().expect("Service has already been started."); + /// Starts all activate [Service]s and runs them to completion. Services + /// are activated by calling [ServiceMonitor::activate]. This function + /// completes once all services have been run to completion. + /// + /// Keep in mind that services can be restarted as long as other services + /// are running (otherwise the node would shutdown). + #[tracing::instrument(skip(self), fields(module = "Service"))] + pub async fn start(mut self) -> anyhow::Result<()> { + let mut ctx = ServiceContext::new_with_services(Arc::clone(&self.status_request)); + + // start only the initially active services for svc in self.services.iter_mut() { - ctx.service_add(svc.id()); - svc.start(&mut own_join_set, ctx.child().with_id(svc.id())).await.context("Starting service")?; + match svc { + Some(svc) if self.status_request.status(svc.svc_id()) == MadaraServiceStatus::On => { + let id = svc.svc_id(); + self.status_actual.activate(id); + + let ctx = ctx.child().with_id(id); + let runner = ServiceRunner::new(ctx, &mut self.join_set); + svc.start(runner).await.context("Starting service")?; + } + _ => continue, + } } - join_set.spawn(drive_joinset(own_join_set)); - Ok(()) - } + // SIGINT & SIGTERM + let runner = ServiceRunner::new(ctx.clone(), &mut self.join_set); + runner.service_loop(|ctx| async move { + let sigint = tokio::signal::ctrl_c(); + let sigterm = async { + match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) { + Ok(mut signal) => signal.recv().await, + Err(_) => core::future::pending().await, // SIGTERM not supported + } + }; - fn id(&self) -> MadaraService { - MadaraService::None - } -} + tokio::select! { + res = sigint => res?, + _ = sigterm => {}, + }; -async fn drive_joinset(mut join_set: JoinSet<anyhow::Result<()>>) -> anyhow::Result<()> { - while let Some(result) = join_set.join_next().await { - match result { - Ok(result) => result?, - Err(panic_error) if panic_error.is_panic() => { - // bubble up panics too - panic::resume_unwind(panic_error.into_panic()); - } - Err(_task_cancelled_error) => {} + ctx.cancel_global(); + + anyhow::Ok(()) + }); + + tracing::debug!("Running services: {:?}", self.status_request.active_set()); + while self.status_request.is_active_some() { + tokio::select! { + // A service has run to completion, mark it as inactive + Some(result) = self.join_set.join_next() => { + match result { + Ok(result) => { + let id = result?; + tracing::debug!("service {id} has shut down"); + self.status_actual.deactivate(id); + self.status_request.deactivate(id); + } + Err(panic_error) if panic_error.is_panic() => { + // bubble up panics too + panic::resume_unwind(panic_error.into_panic()); + } + Err(_task_cancelled_error) => {} + } + }, + // A service has had its status updated, check if it is a + // restart request + Some(ServiceTransport { svc_id, status }) = ctx.service_subscribe() => { + if status == MadaraServiceStatus::On { + if let Some(svc) = self.services[svc_id.index()].as_mut() { + if self.status_actual.status(svc_id) == MadaraServiceStatus::Off { + self.status_actual.activate(svc_id); + + let ctx = ctx.child().with_id(svc_id); + let runner = ServiceRunner::new(ctx, &mut self.join_set); + svc.start(runner) + .await + .context("Starting service")?; + + tracing::debug!("service {svc_id} has started"); + } else { + // reset request + self.status_request.deactivate(svc_id); + } + } + } + }, + else => continue + }; + + tracing::debug!("Services still active: {:?}", self.status_request.active_set()); } - } - Ok(()) + Ok(()) + } } diff --git a/crates/proc-macros/Cargo.toml b/crates/proc-macros/Cargo.toml index ad6b33cec..85c71fdf8 100644 --- a/crates/proc-macros/Cargo.toml +++ b/crates/proc-macros/Cargo.toml @@ -15,10 +15,17 @@ workspace = true targets = ["x86_64-unknown-linux-gnu"] [dependencies] -indoc = { workspace = true } -proc-macro2 = "1.0.86" -quote = "1.0.26" -syn = { version = "2.0.39", features = ["full"] } +indoc.workspace = true +proc-macro2.workspace = true +quote.workspace = true +syn.workspace = true + +# This is currently only used inside code blocks in doc comments +[dev-dependencies] +jsonrpsee = { workspace = true, default-features = true, features = [ + "macros", + "server", +] } [lib] proc-macro = true diff --git a/crates/proc-macros/src/lib.rs b/crates/proc-macros/src/lib.rs index d0bc77920..74ab634c6 100644 --- a/crates/proc-macros/src/lib.rs +++ b/crates/proc-macros/src/lib.rs @@ -45,21 +45,31 @@ //! //! Given this code: //! -//! ```rust,ignore +//! ```rust +//! # use m_proc_macros::versioned_rpc; +//! # use std::sync::Arc; +//! # use std::error::Error; +//! # use jsonrpsee::core::RpcResult; +//! //! #[versioned_rpc("V0_7_1", "starknet")] //! pub trait JsonRpc { //! #[method(name = "blockNumber", and_versions = ["V0_8_0"])] -//! fn block_number(&self) -> anyhow::Result<u64>; +//! fn block_number(&self) -> RpcResult<u64>; //! } //! ``` //! //! The macro will generate the following code: //! -//! ```rust,ignore +//! ```rust +//! # use m_proc_macros::versioned_rpc; +//! # use std::sync::Arc; +//! # use std::error::Error; +//! # use jsonrpsee::core::RpcResult; +//! //! #[jsonrpsee::proc_macros::rpc(server, client, namespace = "starknet")] //! pub trait JsonRpcV0_7_1 { //! #[method(name = "V0_7_1_blockNumber", aliases = ["starknet_V0_8_0blockNumber"])] -//! fn block_number(&self) -> anyhow::Result<u64>; +//! fn block_number(&self) -> RpcResult<u64>; //! } //! ```