Skip to content

Commit 20053ad

Browse files
committed
refactor: Hide Db/Tx lifetimes via Deref/Generics
1 parent e207624 commit 20053ad

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

src/database.rs

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use parking_lot::Mutex;
1111
use std::{
1212
fs::File,
1313
io,
14+
ops::Deref,
1415
path::{Path, PathBuf},
1516
};
1617

@@ -194,19 +195,12 @@ impl Database {
194195
Ok(())
195196
}
196197

197-
pub fn begin_rw(&self) -> Result<Transaction<'_, RW>, TransactionError> {
198-
let context = self.storage_engine.write_context();
199-
let min_snapshot_id = self.transaction_manager.lock().begin_rw(context.snapshot_id)?;
200-
if min_snapshot_id > 0 {
201-
self.storage_engine.unlock(min_snapshot_id - 1);
202-
}
203-
Ok(Transaction::new(context, self))
198+
pub fn begin_ro(&self) -> Result<Transaction<&Self, RO>, TransactionError> {
199+
begin_ro(self)
204200
}
205201

206-
pub fn begin_ro(&self) -> Result<Transaction<'_, RO>, TransactionError> {
207-
let context = self.storage_engine.read_context();
208-
self.transaction_manager.lock().begin_ro(context.snapshot_id);
209-
Ok(Transaction::new(context, self))
202+
pub fn begin_rw(&self) -> Result<Transaction<&Self, RW>, TransactionError> {
203+
begin_rw(self)
210204
}
211205

212206
pub fn state_root(&self) -> B256 {
@@ -244,13 +238,32 @@ impl Database {
244238
}
245239
}
246240

241+
pub fn begin_ro<DB: Deref<Target = Database>>(
242+
db: DB,
243+
) -> Result<Transaction<DB, RO>, TransactionError> {
244+
let context = db.storage_engine.read_context();
245+
db.transaction_manager.lock().begin_ro(context.snapshot_id);
246+
Ok(Transaction::new(context, db))
247+
}
248+
249+
pub fn begin_rw<DB: Deref<Target = Database>>(
250+
db: DB,
251+
) -> Result<Transaction<DB, RW>, TransactionError> {
252+
let context = db.storage_engine.write_context();
253+
let min_snapshot_id = db.transaction_manager.lock().begin_rw(context.snapshot_id)?;
254+
if min_snapshot_id > 0 {
255+
db.storage_engine.unlock(min_snapshot_id - 1);
256+
}
257+
Ok(Transaction::new(context, db))
258+
}
259+
247260
#[cfg(test)]
248261
mod tests {
249262
use super::*;
250263
use crate::{account::Account, path::AddressPath};
251264
use alloy_primitives::{address, Address, U256};
252265
use alloy_trie::{EMPTY_ROOT_HASH, KECCAK_EMPTY};
253-
use std::fs;
266+
use std::{fs, sync::Arc};
254267
use tempdir::TempDir;
255268

256269
#[test]
@@ -315,13 +328,13 @@ mod tests {
315328
let address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
316329

317330
let account1 = Account::new(1, U256::from(100), EMPTY_ROOT_HASH, KECCAK_EMPTY);
318-
let mut tx = db.begin_rw().unwrap();
331+
let mut tx = begin_rw(&db).unwrap();
319332
tx.set_account(AddressPath::for_address(address), Some(account1.clone())).unwrap();
320333

321334
tx.commit().unwrap();
322335

323336
let account2 = Account::new(456, U256::from(123), EMPTY_ROOT_HASH, KECCAK_EMPTY);
324-
let mut tx = db.begin_rw().unwrap();
337+
let mut tx = begin_rw(&db).unwrap();
325338
tx.set_account(AddressPath::for_address(address), Some(account2.clone())).unwrap();
326339

327340
let mut ro_tx = db.begin_ro().unwrap();
@@ -512,4 +525,26 @@ mod tests {
512525
);
513526
}
514527
}
528+
529+
#[test]
530+
fn test_db_arc_tx() {
531+
let tmp_dir = TempDir::new("test_db").unwrap();
532+
let file_path = tmp_dir.path().join("test.db");
533+
let db = Database::create_new(&file_path).unwrap();
534+
535+
let db_arc = Arc::new(db);
536+
537+
let address = address!("0xd8da6bf26964af9d7eed9e03e53415d37aa96045");
538+
let mut tx = begin_rw(db_arc.clone()).unwrap();
539+
tx.set_account(
540+
AddressPath::for_address(address),
541+
Some(Account::new(1, U256::from(100), EMPTY_ROOT_HASH, KECCAK_EMPTY)),
542+
)
543+
.unwrap();
544+
tx.commit().unwrap();
545+
546+
let mut tx = begin_ro(db_arc).unwrap();
547+
let account = tx.get_account(AddressPath::for_address(address)).unwrap().unwrap();
548+
assert_eq!(account, Account::new(1, U256::from(100), EMPTY_ROOT_HASH, KECCAK_EMPTY));
549+
}
515550
}

src/transaction.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use alloy_trie::Nibbles;
1414
pub use error::TransactionError;
1515
pub use manager::TransactionManager;
1616
use sealed::sealed;
17-
use std::{collections::HashMap, fmt::Debug};
17+
use std::{collections::HashMap, fmt::Debug, ops::Deref, sync::Arc};
1818

1919
#[sealed]
2020
pub trait TransactionKind: Debug {}
@@ -34,21 +34,23 @@ impl TransactionKind for RO {}
3434
// Compile-time assertion to ensure that `Transaction` is `Send`
3535
const _: fn() = || {
3636
fn consumer<T: Send>() {}
37-
consumer::<Transaction<'_, RO>>();
38-
consumer::<Transaction<'_, RW>>();
37+
consumer::<Transaction<&Database, RO>>();
38+
consumer::<Transaction<&Database, RW>>();
39+
consumer::<Transaction<Arc<Database>, RO>>();
40+
consumer::<Transaction<Arc<Database>, RW>>();
3941
};
4042

4143
#[derive(Debug)]
42-
pub struct Transaction<'tx, K: TransactionKind> {
44+
pub struct Transaction<DB, K: TransactionKind> {
4345
committed: bool,
4446
context: TransactionContext,
45-
database: &'tx Database,
47+
database: DB,
4648
pending_changes: HashMap<Nibbles, Option<TrieValue>>,
4749
_marker: std::marker::PhantomData<K>,
4850
}
4951

50-
impl<'tx, K: TransactionKind> Transaction<'tx, K> {
51-
pub(crate) fn new(context: TransactionContext, database: &'tx Database) -> Self {
52+
impl<DB: Deref<Target = Database>, K: TransactionKind> Transaction<DB, K> {
53+
pub(crate) fn new(context: TransactionContext, database: DB) -> Self {
5254
Self {
5355
committed: false,
5456
context,
@@ -137,7 +139,7 @@ impl<'tx, K: TransactionKind> Transaction<'tx, K> {
137139
}
138140
}
139141

140-
impl Transaction<'_, RW> {
142+
impl<DB: Deref<Target = Database>> Transaction<DB, RW> {
141143
pub fn set_account(
142144
&mut self,
143145
address_path: AddressPath,
@@ -186,7 +188,7 @@ impl Transaction<'_, RW> {
186188
}
187189
}
188190

189-
impl Transaction<'_, RO> {
191+
impl<DB: Deref<Target = Database>> Transaction<DB, RO> {
190192
pub fn commit(mut self) -> Result<(), TransactionError> {
191193
let mut transaction_manager = self.database.transaction_manager.lock();
192194
transaction_manager.remove_tx(self.context.snapshot_id, false);
@@ -196,7 +198,7 @@ impl Transaction<'_, RO> {
196198
}
197199
}
198200

199-
impl<K: TransactionKind> Drop for Transaction<'_, K> {
201+
impl<DB, K: TransactionKind> Drop for Transaction<DB, K> {
200202
fn drop(&mut self) {
201203
// TODO: panic if the transaction is not committed
202204
}

0 commit comments

Comments
 (0)