diff --git a/crates/blockifier/src/concurrency/versioned_state_proxy.rs b/crates/blockifier/src/concurrency/versioned_state_proxy.rs index cc61f80367..799d0d59ff 100644 --- a/crates/blockifier/src/concurrency/versioned_state_proxy.rs +++ b/crates/blockifier/src/concurrency/versioned_state_proxy.rs @@ -8,7 +8,7 @@ use starknet_api::state::StorageKey; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; use crate::execution::contract_class::ContractClass; -use crate::state::cached_state::{ContractClassMapping, StateCache}; +use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::state_api::{State, StateReader, StateResult}; #[cfg(test)] @@ -42,21 +42,17 @@ impl VersionedState { } } - // Note: Invoke this function after `update_initial_values_of_write_only_access`. - // Transactions that overwrite previously written values are not charged. Hence, altering a - // write-only cell can impact the fee calculation, leading to a re-execution. // TODO(Mohammad, 01/04/2024): Store the read set (and write set) within a shared // object (probabily `VersionedState`). As RefCell operations are not thread-safe. Therefore, // accessing this function should be protected by a mutex to ensure thread safety. - pub fn validate_read_set(&mut self, tx_index: TxIndex, state_cache: &mut StateCache) -> bool { + pub fn validate_read_set(&mut self, tx_index: TxIndex, reads: &StateMaps) -> bool { // If is the first transaction in the chunk, then the read set is valid. Since it has no // predecessors, there's nothing to compare it to. if tx_index == 0 { return true; } - for (&(contract_address, storage_key), expected_value) in - &state_cache.storage_initial_values - { + + for (&(contract_address, storage_key), expected_value) in &reads.storage { let value = self.storage.read(tx_index, (contract_address, storage_key)).expect(READ_ERR); @@ -65,7 +61,7 @@ impl VersionedState { } } - for (&contract_address, expected_value) in &state_cache.nonce_initial_values { + for (&contract_address, expected_value) in &reads.nonces { let value = self.nonces.read(tx_index, contract_address).expect(READ_ERR); if &value != expected_value { @@ -73,7 +69,7 @@ impl VersionedState { } } - for (&contract_address, expected_value) in &state_cache.class_hash_initial_values { + for (&contract_address, expected_value) in &reads.class_hashes { let value = self.class_hashes.read(tx_index, contract_address).expect(READ_ERR); if &value != expected_value { @@ -82,7 +78,7 @@ impl VersionedState { } // Added for symmetry. We currently do not update this initial mapping. - for (&class_hash, expected_value) in &state_cache.compiled_class_hash_initial_values { + for (&class_hash, expected_value) in &reads.compiled_class_hashes { let value = self.compiled_class_hashes.read(tx_index, class_hash).expect(READ_ERR); if &value != expected_value { @@ -100,19 +96,19 @@ impl VersionedState { pub fn apply_writes( &mut self, tx_index: TxIndex, - state_cache: &mut StateCache, + writes: &StateMaps, class_hash_to_class: ContractClassMapping, ) { - for (&key, &value) in &state_cache.storage_writes { + for (&key, &value) in &writes.storage { self.storage.write(tx_index, key, value); } - for (&key, &value) in &state_cache.nonce_writes { + for (&key, &value) in &writes.nonces { self.nonces.write(tx_index, key, value); } - for (&key, &value) in &state_cache.class_hash_writes { + for (&key, &value) in &writes.class_hashes { self.class_hashes.write(tx_index, key, value); } - for (&key, &value) in &state_cache.compiled_class_hash_writes { + for (&key, &value) in &writes.compiled_class_hashes { self.compiled_class_hashes.write(tx_index, key, value); } for (key, value) in class_hash_to_class { diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 0409d839e9..7c8aeaae1d 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -70,11 +70,7 @@ impl CachedState { pub fn update_cache(&mut self, cache_updates: StateCache) { let mut cache = self.cache.borrow_mut(); - - cache.nonce_writes.extend(cache_updates.nonce_writes); - cache.class_hash_writes.extend(cache_updates.class_hash_writes); - cache.storage_writes.extend(cache_updates.storage_writes); - cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes); + cache.writes.extend(&cache_updates.writes); } pub fn update_contract_class_cache( @@ -100,30 +96,32 @@ impl CachedState { // Eliminate storage writes that are identical to the initial value (no change). Assumes // that `set_storage_at` does not affect the state field. - for contract_storage_key in cache.storage_writes.keys() { - if !cache.storage_initial_values.contains_key(contract_storage_key) { + for contract_storage_key in cache.writes.storage.keys() { + if !cache.initial_reads.storage.contains_key(contract_storage_key) { // First access to this cell was write; cache initial value. - cache.storage_initial_values.insert( + cache.initial_reads.storage.insert( *contract_storage_key, self.state.get_storage_at(contract_storage_key.0, contract_storage_key.1)?, ); } } - for contract_address in cache.class_hash_writes.keys() { - if !cache.class_hash_initial_values.contains_key(contract_address) { + for contract_address in cache.writes.class_hashes.keys() { + if !cache.initial_reads.class_hashes.contains_key(contract_address) { // First access to this cell was write; cache initial value. cache - .class_hash_initial_values + .initial_reads + .class_hashes .insert(*contract_address, self.state.get_class_hash_at(*contract_address)?); } } - for contract_address in cache.nonce_writes.keys() { - if !cache.nonce_initial_values.contains_key(contract_address) { + for contract_address in cache.writes.nonces.keys() { + if !cache.initial_reads.nonces.contains_key(contract_address) { // First access to this cell was write; cache initial value. cache - .nonce_initial_values + .initial_reads + .nonces .insert(*contract_address, self.state.get_nonce_at(*contract_address)?); } } @@ -142,7 +140,7 @@ impl CachedState { let class_hash_updates = state_cache.get_class_hash_updates(); let storage_diffs = state_cache.get_storage_updates(); let nonces = state_cache.get_nonce_updates(); - let declared_classes = state_cache.compiled_class_hash_writes.clone(); + let declared_classes = state_cache.writes.compiled_class_hashes.clone(); CommitmentStateDiff { address_to_class_hash: IndexMap::from_iter(class_hash_updates), @@ -348,6 +346,24 @@ impl From for IndexMap, + pub(crate) class_hashes: HashMap, + pub(crate) storage: HashMap, + pub(crate) compiled_class_hashes: HashMap, + pub(crate) declared_contracts: HashMap, +} + +impl StateMaps { + pub fn extend(&mut self, other: &Self) { + self.nonces.extend(&other.nonces); + self.class_hashes.extend(&other.class_hashes); + self.storage.extend(&other.storage); + self.compiled_class_hashes.extend(&other.compiled_class_hashes); + self.declared_contracts.extend(&other.declared_contracts) + } +} /// Caches read and write requests. /// The tracked changes are needed for block state commitment. @@ -355,27 +371,19 @@ impl From for IndexMap, - pub(crate) class_hash_initial_values: HashMap, - pub(crate) storage_initial_values: HashMap, - pub(crate) compiled_class_hash_initial_values: HashMap, - pub(crate) declared_contract_initial_values: HashMap, + pub(crate) initial_reads: StateMaps, // Writer's cached information. - pub(crate) nonce_writes: HashMap, - pub(crate) class_hash_writes: HashMap, - pub(crate) storage_writes: HashMap, - pub(crate) compiled_class_hash_writes: HashMap, - pub(crate) declared_contract_writes: HashMap, + pub(crate) writes: StateMaps, } impl StateCache { fn declare_contract(&mut self, class_hash: ClassHash) { - self.declared_contract_writes.insert(class_hash, true); + self.writes.declared_contracts.insert(class_hash, true); } fn set_declared_contract_initial_values(&mut self, class_hash: ClassHash, is_declared: bool) { - self.declared_contract_initial_values.insert(class_hash, is_declared); + self.initial_reads.declared_contracts.insert(class_hash, is_declared); } fn get_storage_at( @@ -384,15 +392,17 @@ impl StateCache { key: StorageKey, ) -> Option<&StarkFelt> { let contract_storage_key = (contract_address, key); - self.storage_writes + self.writes + .storage .get(&contract_storage_key) - .or_else(|| self.storage_initial_values.get(&contract_storage_key)) + .or_else(|| self.initial_reads.storage.get(&contract_storage_key)) } fn get_nonce_at(&self, contract_address: ContractAddress) -> Option<&Nonce> { - self.nonce_writes + self.writes + .nonces .get(&contract_address) - .or_else(|| self.nonce_initial_values.get(&contract_address)) + .or_else(|| self.initial_reads.nonces.get(&contract_address)) } pub fn set_storage_initial_value( @@ -402,7 +412,7 @@ impl StateCache { value: StarkFelt, ) { let contract_storage_key = (contract_address, key); - self.storage_initial_values.insert(contract_storage_key, value); + self.initial_reads.storage.insert(contract_storage_key, value); } fn set_storage_value( @@ -412,21 +422,22 @@ impl StateCache { value: StarkFelt, ) { let contract_storage_key = (contract_address, key); - self.storage_writes.insert(contract_storage_key, value); + self.writes.storage.insert(contract_storage_key, value); } fn set_nonce_initial_value(&mut self, contract_address: ContractAddress, nonce: Nonce) { - self.nonce_initial_values.insert(contract_address, nonce); + self.initial_reads.nonces.insert(contract_address, nonce); } fn set_nonce_value(&mut self, contract_address: ContractAddress, nonce: Nonce) { - self.nonce_writes.insert(contract_address, nonce); + self.writes.nonces.insert(contract_address, nonce); } fn get_class_hash_at(&self, contract_address: ContractAddress) -> Option<&ClassHash> { - self.class_hash_writes + self.writes + .class_hashes .get(&contract_address) - .or_else(|| self.class_hash_initial_values.get(&contract_address)) + .or_else(|| self.initial_reads.class_hashes.get(&contract_address)) } fn set_class_hash_initial_value( @@ -434,17 +445,18 @@ impl StateCache { contract_address: ContractAddress, class_hash: ClassHash, ) { - self.class_hash_initial_values.insert(contract_address, class_hash); + self.initial_reads.class_hashes.insert(contract_address, class_hash); } fn set_class_hash_write(&mut self, contract_address: ContractAddress, class_hash: ClassHash) { - self.class_hash_writes.insert(contract_address, class_hash); + self.writes.class_hashes.insert(contract_address, class_hash); } fn get_compiled_class_hash(&self, class_hash: ClassHash) -> Option<&CompiledClassHash> { - self.compiled_class_hash_writes + self.writes + .compiled_class_hashes .get(&class_hash) - .or_else(|| self.compiled_class_hash_initial_values.get(&class_hash)) + .or_else(|| self.initial_reads.compiled_class_hashes.get(&class_hash)) } fn set_compiled_class_hash_initial_value( @@ -452,7 +464,7 @@ impl StateCache { class_hash: ClassHash, compiled_class_hash: CompiledClassHash, ) { - self.compiled_class_hash_initial_values.insert(class_hash, compiled_class_hash); + self.initial_reads.compiled_class_hashes.insert(class_hash, compiled_class_hash); } fn set_compiled_class_hash_write( @@ -460,19 +472,19 @@ impl StateCache { class_hash: ClassHash, compiled_class_hash: CompiledClassHash, ) { - self.compiled_class_hash_writes.insert(class_hash, compiled_class_hash); + self.writes.compiled_class_hashes.insert(class_hash, compiled_class_hash); } fn get_storage_updates(&self) -> HashMap { - strict_subtract_mappings(&self.storage_writes, &self.storage_initial_values) + strict_subtract_mappings(&self.writes.storage, &self.initial_reads.storage) } fn get_class_hash_updates(&self) -> HashMap { - strict_subtract_mappings(&self.class_hash_writes, &self.class_hash_initial_values) + strict_subtract_mappings(&self.writes.class_hashes, &self.initial_reads.class_hashes) } fn get_nonce_updates(&self) -> HashMap { - strict_subtract_mappings(&self.nonce_writes, &self.nonce_initial_values) + strict_subtract_mappings(&self.writes.nonces, &self.initial_reads.nonces) } fn get_compiled_class_hash_updates(&self) -> HashMap { @@ -482,8 +494,8 @@ impl StateCache { // class hash writes keys are not a subset of compiled class hash initial values keys. subtract_mappings( - &self.compiled_class_hash_writes, - &self.compiled_class_hash_initial_values, + &self.writes.compiled_class_hashes, + &self.initial_reads.compiled_class_hashes, ) } } @@ -525,49 +537,6 @@ impl<'a, S: State + ?Sized> StateReader for MutRefState<'a, S> { } } -impl<'a, S: State + ?Sized> State for MutRefState<'a, S> { - fn set_storage_at( - &mut self, - contract_address: ContractAddress, - key: StorageKey, - value: StarkFelt, - ) -> StateResult<()> { - self.0.set_storage_at(contract_address, key, value) - } - - fn increment_nonce(&mut self, contract_address: ContractAddress) -> StateResult<()> { - self.0.increment_nonce(contract_address) - } - - fn set_class_hash_at( - &mut self, - contract_address: ContractAddress, - class_hash: ClassHash, - ) -> StateResult<()> { - self.0.set_class_hash_at(contract_address, class_hash) - } - - fn set_contract_class( - &mut self, - class_hash: ClassHash, - contract_class: ContractClass, - ) -> StateResult<()> { - self.0.set_contract_class(class_hash, contract_class) - } - - fn set_compiled_class_hash( - &mut self, - class_hash: ClassHash, - compiled_class_hash: CompiledClassHash, - ) -> StateResult<()> { - self.0.set_compiled_class_hash(class_hash, compiled_class_hash) - } - - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet) { - self.0.add_visited_pcs(class_hash, pcs) - } -} - pub type TransactionalState<'a, S> = CachedState>>; /// Adds the ability to perform a transactional execution. diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 0c14fe002d..71dd1054bd 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -27,9 +27,9 @@ fn set_initial_state_values( assert!(*state.cache.borrow() == StateCache::default(), "Cache already initialized."); state.class_hash_to_class.replace(class_hash_to_class); - state.cache.get_mut().class_hash_initial_values.extend(class_hash_initial_values); - state.cache.get_mut().nonce_initial_values.extend(nonce_initial_values); - state.cache.get_mut().storage_initial_values.extend(storage_initial_values); + state.cache.get_mut().initial_reads.class_hashes.extend(class_hash_initial_values); + state.cache.get_mut().initial_reads.nonces.extend(nonce_initial_values); + state.cache.get_mut().initial_reads.storage.extend(storage_initial_values); } #[test] @@ -112,8 +112,8 @@ fn declare_contract() { let class_hash = test_contract.get_class_hash(); let contract_class = test_contract.get_class(); - assert_eq!(state.cache.borrow().declared_contract_writes.get(&class_hash), None); - assert_eq!(state.cache.borrow().declared_contract_initial_values.get(&class_hash), None); + assert_eq!(state.cache.borrow().writes.declared_contracts.get(&class_hash), None); + assert_eq!(state.cache.borrow().initial_reads.declared_contracts.get(&class_hash), None); // Reading an undeclared contract class. assert_matches!( @@ -122,12 +122,12 @@ fn declare_contract() { undeclared_class_hash == class_hash ); assert_eq!( - *state.cache.borrow().declared_contract_initial_values.get(&class_hash).unwrap(), + *state.cache.borrow().initial_reads.declared_contracts.get(&class_hash).unwrap(), false ); state.set_contract_class(class_hash, contract_class).unwrap(); - assert_eq!(*state.cache.borrow().declared_contract_writes.get(&class_hash).unwrap(), true); + assert_eq!(*state.cache.borrow().writes.declared_contracts.get(&class_hash).unwrap(), true); } #[test] @@ -574,3 +574,26 @@ fn test_state_changes_keys() { } ) } + +#[rstest] +fn test_state_maps() { + let contract_address1 = contract_address!("0x101"); + let storage_key1 = StorageKey(patricia_key!("0x102")); + let class_hash1 = ClassHash(stark_felt!("0x103")); + let nonce1 = Nonce(stark_felt!("0x104")); + let compiled_class_hash1 = CompiledClassHash(stark_felt!("0x105")); + let some_felt1 = stark_felt!("0x106"); + let maps = StateMaps { + nonces: HashMap::from([(contract_address1, nonce1)]), + class_hashes: HashMap::from([(contract_address1, class_hash1)]), + storage: HashMap::from([((contract_address1, storage_key1), some_felt1)]), + compiled_class_hashes: HashMap::from([(class_hash1, compiled_class_hash1)]), + declared_contracts: HashMap::from([(class_hash1, true)]), + }; + + // Test that `extend` extends all hash maps (by constructing `maps` without default values). + let mut empty = StateMaps::default(); + empty.extend(&maps); + + assert_eq!(maps, empty); +} diff --git a/crates/blockifier/src/test_utils/struct_impls.rs b/crates/blockifier/src/test_utils/struct_impls.rs index 0bd0998713..df36454e40 100644 --- a/crates/blockifier/src/test_utils/struct_impls.rs +++ b/crates/blockifier/src/test_utils/struct_impls.rs @@ -161,6 +161,10 @@ impl BlockContext { ..Self::create_for_account_testing() } } + + pub fn create_for_account_testing_with_concurrency_mode(concurrency_mode: bool) -> Self { + Self { concurrency_mode, ..Self::create_for_account_testing() } + } } impl CallExecution { diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index da3b3d69c1..1814fa0c22 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -7,7 +7,8 @@ use starknet_api::deprecated_contract_class::EntryPointType; use starknet_api::hash::StarkFelt; use starknet_api::transaction::{Calldata, Fee, ResourceBounds, TransactionVersion}; -use crate::abi::abi_utils::selector_from_name; +use crate::abi::abi_utils::{get_fee_token_var_address, selector_from_name}; +use crate::abi::sierra_types::next_storage_key; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::{CallInfo, Retdata}; use crate::execution::contract_class::ContractClass; @@ -301,8 +302,8 @@ impl AccountTransaction { Ok(()) } - fn handle_fee( - state: &mut dyn State, + fn handle_fee( + state: &mut TransactionalState<'_, S>, tx_context: Arc, actual_fee: Fee, charge_fee: bool, @@ -315,8 +316,14 @@ impl AccountTransaction { // TODO(Amos, 8/04/2024): Add test for this assert. Self::assert_actual_fee_in_bounds(&tx_context, actual_fee)?; - // Charge fee. - let fee_transfer_call_info = Self::execute_fee_transfer(state, tx_context, actual_fee)?; + let fee_transfer_call_info = if tx_context.block_context.concurrency_mode + && tx_context.block_context.block_info.sequencer_address + != tx_context.tx_info.sender_address() + { + Self::concurrency_execute_fee_transfer(state, tx_context, actual_fee)? + } else { + Self::execute_fee_transfer(state, tx_context, actual_fee)? + }; Ok(Some(fee_transfer_call_info)) } @@ -332,8 +339,6 @@ impl AccountTransaction { let msb_amount = StarkFelt::from(0_u8); let TransactionContext { block_context, tx_info } = tx_context.as_ref(); - - // TODO(Gilad): add test that correct fee address is taken, once we add V3 test support. let storage_address = block_context.chain_info.fee_token_address(&tx_info.fee_type()); let fee_transfer_call = CallEntryPoint { class_hash: None, @@ -359,6 +364,41 @@ impl AccountTransaction { .map_err(TransactionFeeError::ExecuteFeeTransferError)?) } + /// Handles fee transfer in concurrent execution. + /// + /// Accessing and updating the sequencer balance at this stage is a bottleneck; this function + /// manipulates the state to avoid that part. + /// Note: the returned transfer call info is partial, and should be completed at the commit + /// stage, as well as the actual sequencer balance. + fn concurrency_execute_fee_transfer( + state: &mut TransactionalState<'_, S>, + tx_context: Arc, + actual_fee: Fee, + ) -> TransactionExecutionResult { + let TransactionContext { block_context, tx_info } = tx_context.as_ref(); + let fee_address = block_context.chain_info.fee_token_address(&tx_info.fee_type()); + let sequencer_address = block_context.block_info.sequencer_address; + let sequencer_balance_key_low = get_fee_token_var_address(sequencer_address); + let sequencer_balance_key_high = next_storage_key(&sequencer_balance_key_low) + .expect("Cannot get sequencer balance high key."); + let mut transfer_state = CachedState::create_transactional(state); + + // Set the initial sequencer balance to avoid tarnishing the read-set of the transaction. + let cache = transfer_state.cache.get_mut(); + for key in [sequencer_balance_key_low, sequencer_balance_key_high] { + cache.set_storage_initial_value(fee_address, key, StarkFelt::ZERO); + } + + let fee_transfer_call_info = + AccountTransaction::execute_fee_transfer(&mut transfer_state, tx_context, actual_fee); + // Commit without updating the sequencer balance. + let storage_writes = &mut transfer_state.cache.get_mut().writes.storage; + storage_writes.remove(&(fee_address, sequencer_balance_key_low)); + storage_writes.remove(&(fee_address, sequencer_balance_key_high)); + transfer_state.commit(); + fee_transfer_call_info + } + fn run_execute( &self, state: &mut S, @@ -628,7 +668,6 @@ impl ExecutableTransaction for AccountTransaction { validate, charge_fee, )?; - let fee_transfer_call_info = Self::handle_fee(state, tx_context, final_fee, charge_fee)?; let tx_execution_info = TransactionExecutionInfo { diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index de23a838c9..bf249f7731 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -19,6 +19,7 @@ use starknet_api::{calldata, class_hash, contract_address, patricia_key, stark_f use crate::abi::abi_utils::{ get_fee_token_var_address, get_storage_var_address, selector_from_name, }; +use crate::abi::sierra_types::next_storage_key; use crate::context::BlockContext; use crate::execution::contract_class::{ContractClass, ContractClassV1}; use crate::execution::entry_point::EntryPointExecutionContext; @@ -1161,3 +1162,89 @@ fn test_count_actual_storage_changes( assert_eq!(expected_storage_update_transfer, state_changes_transfer.storage_updates); assert_eq!(state_changes_count_3, expected_state_changes_count_3); } + +#[rstest] +fn test_concurrency_execute_fee_transfer(#[values(FeeType::Eth, FeeType::Strk)] fee_type: FeeType) { + const STORAGE_WRITE_HIGH: u128 = 150; + const STORAGE_WRITE_LOW: u128 = 100; + const STORAGE_READ_LOW: u128 = 50; + let block_context = BlockContext::create_for_account_testing_with_concurrency_mode(true); + let empty_contract = FeatureContract::Empty(CairoVersion::Cairo1); + let account = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); + let chain_info = &block_context.chain_info; + let state = &mut test_state(chain_info, BALANCE, &[(account, 1)]); + let class_hash = empty_contract.get_class_hash(); + let class_info = calculate_class_info_for_testing(empty_contract.get_class()); + let sender_address = account.get_instance_address(0); + + let account_tx = declare_tx( + declare_tx_args! { + sender_address, + version: TransactionVersion::THREE, + resource_bounds: l1_resource_bounds(MAX_L1_GAS_AMOUNT, MAX_L1_GAS_PRICE), + class_hash, + }, + class_info.clone(), + ); + + let fee_token_address = block_context.chain_info.fee_token_address(&fee_type); + let sequencer_address = block_context.block_info.sequencer_address; + let sequencer_balance_key_low = get_fee_token_var_address(sequencer_address); + let sequencer_balance_key_high = next_storage_key(&sequencer_balance_key_low).unwrap(); + // Case 1: The transaction did not read form/ write to the sequenser balance before executing + // fee transfer. + let mut transactional_state = CachedState::create_transactional(state); + account_tx.execute_raw(&mut transactional_state, &block_context, true, false).unwrap(); + let transactional_cache = transactional_state.cache.borrow(); + for storage in [ + transactional_cache.initial_reads.storage.clone(), + transactional_cache.writes.storage.clone(), + ] { + for seq_key in [sequencer_balance_key_low, sequencer_balance_key_high] { + assert!(storage.get(&(fee_token_address, seq_key)).is_none()); + } + } + + // Case 2: The transaction read from and write to the sequenser balance before executing fee + // transfer. + + // Set the sequencer balance to a constant value to check that the read set did not changed. + fund_account(chain_info, sequencer_address, STORAGE_READ_LOW, &mut state.state); + let mut transactional_state = CachedState::create_transactional(state); + + // Set the sequencer balance write set to a constant value. + // Note that it is enough to set the storage_write as execute_raw will update the + // storage_initial_values. + for (seq_key, value) in [ + (sequencer_balance_key_low, STORAGE_WRITE_LOW), + (sequencer_balance_key_high, STORAGE_WRITE_HIGH), + ] { + transactional_state.set_storage_at(fee_token_address, seq_key, stark_felt!(value)).unwrap(); + } + + account_tx.execute_raw(&mut transactional_state, &block_context, true, false).unwrap(); + // Check that the sequencer balance was not changed. + let storage_write = transactional_state.cache.borrow().writes.storage.clone(); + let storage_initial_values = transactional_state.cache.borrow().initial_reads.storage.clone(); + + for (seq_write_val, expexted_write_val) in [ + ( + storage_write.get(&(fee_token_address, sequencer_balance_key_low)), + stark_felt!(STORAGE_WRITE_LOW), + ), + ( + storage_initial_values.get(&(fee_token_address, sequencer_balance_key_low)), + stark_felt!(STORAGE_READ_LOW), + ), + ( + storage_write.get(&(fee_token_address, sequencer_balance_key_high)), + stark_felt!(STORAGE_WRITE_HIGH), + ), + ( + storage_initial_values.get(&(fee_token_address, sequencer_balance_key_high)), + StarkFelt::ZERO, + ), + ] { + assert_eq!(*seq_write_val.unwrap(), expexted_write_val); + } +} diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index d08dba37cc..5dd7a29179 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -303,12 +303,26 @@ impl PyBlockExecutor { } #[cfg(any(feature = "testing", test))] - #[pyo3(signature = (general_config, path))] + #[pyo3(signature = (general_config, path, max_state_diff_size))] #[staticmethod] - fn create_for_testing(general_config: PyGeneralConfig, path: std::path::PathBuf) -> Self { + fn create_for_testing( + general_config: PyGeneralConfig, + path: std::path::PathBuf, + max_state_diff_size: usize, + ) -> Self { + use blockifier::bouncer::BouncerWeights; use blockifier::state::global_cache::GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST; Self { - bouncer_config: BouncerConfig::max(), + bouncer_config: BouncerConfig { + block_max_capacity: BouncerWeights { + state_diff_size: max_state_diff_size, + ..BouncerWeights::max(false) + }, + block_max_capacity_with_keccak: BouncerWeights { + state_diff_size: max_state_diff_size, + ..BouncerWeights::max(true) + }, + }, storage: Box::new(PapyrusStorage::new_for_testing( path, &general_config.starknet_os_config.chain_id, diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index 766738b5b9..a00ea38b9e 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -23,7 +23,7 @@ fn global_contract_cache_update() { let temp_storage_path = tempfile::tempdir().unwrap().into_path(); let mut block_executor = - PyBlockExecutor::create_for_testing(PyGeneralConfig::default(), temp_storage_path); + PyBlockExecutor::create_for_testing(PyGeneralConfig::default(), temp_storage_path, 4000); block_executor .append_block( 0,