diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index c61ee40529..c54acedbaa 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -82,6 +82,7 @@ const ( callMaxStepsF = "rpc-call-max-steps" corsEnableF = "rpc-cors-enable" versionedConstantsFileF = "versioned-constants-file" + vmConcurrencyModeF = "vm-concurrency-mode" defaultConfig = "" defaulHost = "localhost" @@ -119,6 +120,7 @@ const ( defaultGwTimeout = 5 * time.Second defaultCorsEnable = false defaultVersionedConstantsFile = "" + defaultVMConcurrencyMode = false configFlagUsage = "The YAML configuration file." logLevelFlagUsage = "Options: trace, debug, info, warn, error." @@ -170,6 +172,7 @@ const ( "The upper limit is 4 million steps, and any higher value will still be capped at 4 million." corsEnableUsage = "Enable CORS on RPC endpoints" versionedConstantsFileUsage = "Use custom versioned constants from provided file" + vmConcurrencyModeUsage = "Enable VM concurrency mode" ) var Version string @@ -355,6 +358,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage) junoCmd.Flags().String(versionedConstantsFileF, defaultVersionedConstantsFile, versionedConstantsFileUsage) junoCmd.MarkFlagsMutuallyExclusive(p2pFeederNodeF, p2pPeersF) + junoCmd.Flags().Bool(vmConcurrencyModeF, defaultVMConcurrencyMode, vmConcurrencyModeUsage) junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath)) diff --git a/node/node.go b/node/node.go index abd0b93990..7286f7e8d1 100644 --- a/node/node.go +++ b/node/node.go @@ -66,6 +66,7 @@ type Config struct { PendingPollInterval time.Duration `mapstructure:"pending-poll-interval"` RemoteDB string `mapstructure:"remote-db"` VersionedConstantsFile string `mapstructure:"versioned-constants-file"` + VMConcurrencyMode bool `mapstructure:"vm-concurrency-mode"` Metrics bool `mapstructure:"metrics"` MetricsHost string `mapstructure:"metrics-host"` @@ -179,7 +180,7 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen services = append(services, synchronizer) } - throttledVM := NewThrottledVM(vm.New(false, log), cfg.MaxVMs, int32(cfg.MaxVMQueue)) + throttledVM := NewThrottledVM(vm.New(cfg.VMConcurrencyMode, log), cfg.MaxVMs, int32(cfg.MaxVMQueue)) var syncReader sync.Reader = &sync.NoopSynchronizer{} if synchronizer != nil { diff --git a/node/throttled_vm.go b/node/throttled_vm.go index 6ae72af728..2a6477dac9 100644 --- a/node/throttled_vm.go +++ b/node/throttled_vm.go @@ -15,7 +15,7 @@ type ThrottledVM struct { func NewThrottledVM(res vm.VM, concurrenyBudget uint, maxQueueLen int32) *ThrottledVM { return &ThrottledVM{ - Throttler: utils.NewThrottler[vm.VM](concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen), + Throttler: utils.NewThrottler(concurrenyBudget, &res).WithMaxQueueLen(maxQueueLen), } } diff --git a/vm/rust/Cargo.toml b/vm/rust/Cargo.toml index 2909483444..ae87a80d97 100644 --- a/vm/rust/Cargo.toml +++ b/vm/rust/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] serde = "1.0.208" serde_json = { version = "1.0.125", features = ["raw_value"] } -blockifier = "0.8.0-rc.3" +blockifier = { version = "0.8.0-rc.3", features = ["concurrency"] } starknet_api = "0.13.0-rc.1" cairo-vm = "=1.0.1" starknet-types-core = { version = "0.1.5", features = ["hash", "prime-bigint"] } @@ -18,6 +18,7 @@ once_cell = "1.19.0" lazy_static = "1.4.0" semver = "1.0.22" anyhow = "1.0.81" +num_cpus = "1.15" [lib] crate-type = ["staticlib"] diff --git a/vm/rust/src/lib.rs b/vm/rust/src/lib.rs index fba659199b..0fb3538dfc 100644 --- a/vm/rust/src/lib.rs +++ b/vm/rust/src/lib.rs @@ -21,6 +21,10 @@ use blockifier::bouncer::BouncerConfig; use blockifier::fee::{fee_utils, gas_usage}; use blockifier::transaction::objects::GasVector; use blockifier::{ + blockifier::{ + config::{ConcurrencyConfig, TransactionExecutorConfig}, + transaction_executor::{TransactionExecutor, TransactionExecutorError}, + }, context::{BlockContext, ChainInfo, FeeTokenAddresses, TransactionContext}, execution::{ contract_class::ClassInfo, @@ -33,7 +37,6 @@ use blockifier::{ }, objects::{DeprecatedTransactionInfo, HasRelatedFeeType, TransactionInfo}, transaction_execution::Transaction, - transactions::ExecutableTransaction, }, versioned_constants::VersionedConstants, }; @@ -230,11 +233,26 @@ pub extern "C" fn cairoVMExecute( None, concurrency_mode, ); - let charge_fee = skip_charge_fee == 0; - let validate = skip_validate == 0; + let _charge_fee = skip_charge_fee == 0; + let _validate = skip_validate == 0; let mut trace_buffer = Vec::with_capacity(10_000); + let n_workers = num_cpus::get() / 2; + // Initialize the TransactionExecutor + let config = TransactionExecutorConfig { + concurrency_config: ConcurrencyConfig { + enabled: concurrency_mode, + chunk_size: n_workers * 3, + n_workers, + }, + }; + + let mut executor = TransactionExecutor::new(state, block_context.clone(), config); + + let mut transactions: Vec = Vec::new(); + + // Prepare transactions for (txn_index, txn_and_query_bit) in txns_and_query_bits.iter().enumerate() { let class_info = match txn_and_query_bit.txn.clone() { StarknetApiTransaction::Declare(_) => { @@ -277,37 +295,43 @@ pub extern "C" fn cairoVMExecute( return; } - let mut txn_state = CachedState::create_transactional(&mut state); - let fee_type; - let minimal_l1_gas_amount_vector: Option; - let res = match txn.unwrap() { - Transaction::AccountTransaction(t) => { - fee_type = t.fee_type(); - minimal_l1_gas_amount_vector = - Some(gas_usage::estimate_minimal_gas_vector(&block_context, &t).unwrap()); - t.execute(&mut txn_state, &block_context, charge_fee, validate) - } - Transaction::L1HandlerTransaction(t) => { - fee_type = t.fee_type(); - minimal_l1_gas_amount_vector = None; - t.execute(&mut txn_state, &block_context, charge_fee, validate) + match txn { + Ok(txn) => transactions.push(txn), + Err(_) => { + report_error( + reader_handle, + "failed to create transaction", + txn_index as i64, + ); + return; } - }; + } + } + // Execute transactions + let results = executor.execute_txs(&transactions); + let mut block_state = executor.block_state.take().unwrap(); + + // Process results + for (txn_index, res) in results.into_iter().enumerate() { match res { Err(error) => { let err_string = match &error { - ContractConstructorExecutionFailed(e) => format!("{error} {e}"), - ExecutionError { error: e, .. } | ValidateTransactionError { error: e, .. } => { - format!("{error} {e}") - } + TransactionExecutorError::TransactionExecutionError(err) => match err { + ContractConstructorExecutionFailed(e) => format!("{error} {e}"), + ExecutionError { error: e, .. } + | ValidateTransactionError { error: e, .. } => { + format!("{error} {e}") + } + other => other.to_string(), + }, other => other.to_string(), }; report_error( reader_handle, format!( "failed txn {} reason: {}", - txn_and_query_bit.txn_hash, err_string, + txns_and_query_bits[txn_index].txn_hash, err_string, ) .as_str(), txn_index as i64, @@ -326,6 +350,20 @@ pub extern "C" fn cairoVMExecute( // we are estimating fee, override actual fee calculation if t.transaction_receipt.fee.0 == 0 { + let minimal_l1_gas_amount_vector: Option; + let fee_type; + match &transactions[txn_index] { + Transaction::AccountTransaction(at) => { + fee_type = at.fee_type(); + minimal_l1_gas_amount_vector = Some( + gas_usage::estimate_minimal_gas_vector(&block_context, at).unwrap(), + ); + } + Transaction::L1HandlerTransaction(ht) => { + fee_type = ht.fee_type(); + minimal_l1_gas_amount_vector = None; + } + } let minimal_l1_gas_amount_vector = minimal_l1_gas_amount_vector.unwrap_or_default(); let gas_consumed = t @@ -359,8 +397,13 @@ pub extern "C" fn cairoVMExecute( .try_into() .unwrap_or(u64::MAX); - let trace = - jsonrpc::new_transaction_trace(&txn_and_query_bit.txn, t, &mut txn_state); + let mut txn_state = CachedState::create_transactional(&mut block_state); + + let trace = jsonrpc::new_transaction_trace( + &txns_and_query_bits[txn_index].txn, + t, + &mut txn_state, + ); if let Err(e) = trace { report_error( reader_handle, @@ -381,7 +424,6 @@ pub extern "C" fn cairoVMExecute( append_trace(reader_handle, trace.as_ref().unwrap(), &mut trace_buffer); } } - txn_state.commit(); } }