@@ -11,6 +11,7 @@ use parking_lot::Mutex;
1111use std:: {
1212 fs:: File ,
1313 io,
14+ ops:: Deref ,
1415 path:: { Path , PathBuf } ,
1516} ;
1617
@@ -194,19 +195,12 @@ impl Database {
194195 Ok ( ( ) )
195196 }
196197
197- pub fn begin_rw ( & self ) -> Result < Transaction < ' _ , RW > , TransactionError > {
198- let context = self . storage_engine . write_context ( ) ;
199- let min_snapshot_id = self . transaction_manager . lock ( ) . begin_rw ( context. snapshot_id ) ?;
200- if min_snapshot_id > 0 {
201- self . storage_engine . unlock ( min_snapshot_id - 1 ) ;
202- }
203- Ok ( Transaction :: new ( context, self ) )
198+ pub fn begin_ro ( & self ) -> Result < Transaction < & Self , RO > , TransactionError > {
199+ begin_ro ( self )
204200 }
205201
206- pub fn begin_ro ( & self ) -> Result < Transaction < ' _ , RO > , TransactionError > {
207- let context = self . storage_engine . read_context ( ) ;
208- self . transaction_manager . lock ( ) . begin_ro ( context. snapshot_id ) ;
209- Ok ( Transaction :: new ( context, self ) )
202+ pub fn begin_rw ( & self ) -> Result < Transaction < & Self , RW > , TransactionError > {
203+ begin_rw ( self )
210204 }
211205
212206 pub fn state_root ( & self ) -> B256 {
@@ -244,13 +238,32 @@ impl Database {
244238 }
245239}
246240
241+ pub fn begin_ro < DB : Deref < Target = Database > > (
242+ db : DB ,
243+ ) -> Result < Transaction < DB , RO > , TransactionError > {
244+ let context = db. storage_engine . read_context ( ) ;
245+ db. transaction_manager . lock ( ) . begin_ro ( context. snapshot_id ) ;
246+ Ok ( Transaction :: new ( context, db) )
247+ }
248+
249+ pub fn begin_rw < DB : Deref < Target = Database > > (
250+ db : DB ,
251+ ) -> Result < Transaction < DB , RW > , TransactionError > {
252+ let context = db. storage_engine . write_context ( ) ;
253+ let min_snapshot_id = db. transaction_manager . lock ( ) . begin_rw ( context. snapshot_id ) ?;
254+ if min_snapshot_id > 0 {
255+ db. storage_engine . unlock ( min_snapshot_id - 1 ) ;
256+ }
257+ Ok ( Transaction :: new ( context, db) )
258+ }
259+
247260#[ cfg( test) ]
248261mod tests {
249262 use super :: * ;
250263 use crate :: { account:: Account , path:: AddressPath } ;
251264 use alloy_primitives:: { address, Address , U256 } ;
252265 use alloy_trie:: { EMPTY_ROOT_HASH , KECCAK_EMPTY } ;
253- use std:: fs ;
266+ use std:: { fs , sync :: Arc } ;
254267 use tempdir:: TempDir ;
255268
256269 #[ test]
@@ -315,13 +328,13 @@ mod tests {
315328 let address = address ! ( "0xd8da6bf26964af9d7eed9e03e53415d37aa96045" ) ;
316329
317330 let account1 = Account :: new ( 1 , U256 :: from ( 100 ) , EMPTY_ROOT_HASH , KECCAK_EMPTY ) ;
318- let mut tx = db . begin_rw ( ) . unwrap ( ) ;
331+ let mut tx = begin_rw ( & db ) . unwrap ( ) ;
319332 tx. set_account ( AddressPath :: for_address ( address) , Some ( account1. clone ( ) ) ) . unwrap ( ) ;
320333
321334 tx. commit ( ) . unwrap ( ) ;
322335
323336 let account2 = Account :: new ( 456 , U256 :: from ( 123 ) , EMPTY_ROOT_HASH , KECCAK_EMPTY ) ;
324- let mut tx = db . begin_rw ( ) . unwrap ( ) ;
337+ let mut tx = begin_rw ( & db ) . unwrap ( ) ;
325338 tx. set_account ( AddressPath :: for_address ( address) , Some ( account2. clone ( ) ) ) . unwrap ( ) ;
326339
327340 let mut ro_tx = db. begin_ro ( ) . unwrap ( ) ;
@@ -512,4 +525,26 @@ mod tests {
512525 ) ;
513526 }
514527 }
528+
529+ #[ test]
530+ fn test_db_arc_tx ( ) {
531+ let tmp_dir = TempDir :: new ( "test_db" ) . unwrap ( ) ;
532+ let file_path = tmp_dir. path ( ) . join ( "test.db" ) ;
533+ let db = Database :: create_new ( & file_path) . unwrap ( ) ;
534+
535+ let db_arc = Arc :: new ( db) ;
536+
537+ let address = address ! ( "0xd8da6bf26964af9d7eed9e03e53415d37aa96045" ) ;
538+ let mut tx = begin_rw ( db_arc. clone ( ) ) . unwrap ( ) ;
539+ tx. set_account (
540+ AddressPath :: for_address ( address) ,
541+ Some ( Account :: new ( 1 , U256 :: from ( 100 ) , EMPTY_ROOT_HASH , KECCAK_EMPTY ) ) ,
542+ )
543+ . unwrap ( ) ;
544+ tx. commit ( ) . unwrap ( ) ;
545+
546+ let mut tx = begin_ro ( db_arc) . unwrap ( ) ;
547+ let account = tx. get_account ( AddressPath :: for_address ( address) ) . unwrap ( ) . unwrap ( ) ;
548+ assert_eq ! ( account, Account :: new( 1 , U256 :: from( 100 ) , EMPTY_ROOT_HASH , KECCAK_EMPTY ) ) ;
549+ }
515550}
0 commit comments