Skip to content

Commit

Permalink
[trie] Unify implementation of DiskTrieIterator and MemTrieIterator (#…
Browse files Browse the repository at this point in the history
…12813)

We had separate implementations for DiskTrieIterator and MemTrieIterator
that were originally copy-pasted from the same source. This PR unifies
the implementation while exposing a simple interface to get nodes and
values from the trie.

Part of issue #12361
  • Loading branch information
shreyan-gupta authored Jan 28, 2025
1 parent e6bb098 commit 57b9b81
Show file tree
Hide file tree
Showing 14 changed files with 585 additions and 770 deletions.
1 change: 0 additions & 1 deletion core/store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
extern crate core;

use crate::db::{refcount, DBIterator, DBOp, DBSlice, DBTransaction, Database, StoreStatistics};
pub use crate::trie::iterator::{TrieIterator, TrieTraversalItem};
pub use crate::trie::update::{TrieUpdate, TrieUpdateIterator, TrieUpdateValuePtr};
pub use crate::trie::{
estimator, resharding_v2, ApplyStatePartResult, KeyForStateChanges, KeyLookupMode, NibbleSlice,
Expand Down
500 changes: 46 additions & 454 deletions core/store/src/trie/iterator.rs

Large diffs are not rendered by default.

310 changes: 41 additions & 269 deletions core/store/src/trie/mem/iter.rs
Original file line number Diff line number Diff line change
@@ -1,288 +1,60 @@
//! Iterator that traverses a memtrie in key order.
//!
//! This is essentially a copy of the `DiskTrieIterator`, with the following notable differences:
//! - It doesn't support extra options like remembering nodes or pruning;
//! - It uses None to represent an "empty" placeholder node rather than `TrieNode::Empty`;
//! - MemTrieNodeView splits Branch and BranchWithValue into separate variants, whereas TrieNode
//! handles them in a single variant with an optional value field, but the iteration logic
//! remains the same.
//! - Memtrie code paths don't return any errors, except when looking up the value from the State
//! column.
//!
//! `MemTrieIterator` is tested together with `DiskTrieIterator`, by the latter's tests.
use near_primitives::errors::StorageError;
use near_primitives::hash::CryptoHash;
use near_primitives::state::FlatStateValue;

use crate::trie::iterator::TrieItem;
use crate::trie::ops::interface::GenericTrieInternalStorage;
use crate::trie::ops::iter::TrieIteratorImpl;
use crate::trie::OptimizedValueRef;
use crate::{NibbleSlice, Trie};
use crate::Trie;

use super::arena::hybrid::HybridArenaMemory;
use super::arena::ArenaMemory;
use super::node::{MemTrieNodePtr, MemTrieNodeView};
use super::arena::Arena;
use super::memtrie_update::MemTrieNode;
use super::memtries::MemTries;
use super::node::MemTrieNodeId;

/// Crumb is a piece of trie iteration state. It describes a node on the trail and processing status of that node.
#[derive(Debug)]
struct Crumb<'a, M: ArenaMemory> {
/// View of the node this crumb refers to; `None` is the "empty" placeholder node.
node: Option<MemTrieNodeView<'a, M>>,
/// How far processing of `node` has progressed.
status: CrumbStatus,
/// When set, `increment` moves the crumb straight to `CrumbStatus::Exiting`.
prefix_boundary: bool,
}

/// The status of processing of a node during trie iteration.
/// Each node is processed in the following order:
/// Entering -> At -> AtChild(0) -> ... -> AtChild(15) -> Exiting
#[derive(Debug, Clone, Copy)]
enum CrumbStatus {
/// Initial status of a crumb that was just pushed onto the trail.
Entering,
/// The node itself is being processed (this is where a value, if any, is yielded).
At,
/// The child at the given nibble index (0..=15) of a branch node is being processed.
AtChild(u8),
/// Processing of the node is finished; the crumb is popped from the trail next.
Exiting,
}

impl<'a, M: ArenaMemory> Crumb<'a, M> {
    /// Advances `status` to the next processing stage of this crumb.
    ///
    /// Transitions:
    /// - a crumb marked as `prefix_boundary`, or one without a node, always goes to `Exiting`;
    /// - `Entering` goes to `At`;
    /// - on branch nodes (with or without a value) `At` goes to `AtChild(0)` and
    ///   `AtChild(x)` goes to `AtChild(x + 1)` while `x < 15`;
    /// - everything else goes to `Exiting`.
    fn increment(&mut self) {
        let next = if self.prefix_boundary {
            CrumbStatus::Exiting
        } else {
            // Both branch flavours step through their children the same way.
            let is_branch = matches!(
                self.node,
                Some(MemTrieNodeView::Branch { .. } | MemTrieNodeView::BranchWithValue { .. })
            );
            match self.status {
                _ if self.node.is_none() => CrumbStatus::Exiting,
                CrumbStatus::Entering => CrumbStatus::At,
                CrumbStatus::At if is_branch => CrumbStatus::AtChild(0),
                CrumbStatus::AtChild(idx) if is_branch && idx < 15 => {
                    CrumbStatus::AtChild(idx + 1)
                }
                _ => CrumbStatus::Exiting,
            }
        };
        self.status = next;
    }
}

/// Trie iteration is done using a stack based approach.
///
/// There are two stacks that we track while iterating: the trail and the key_nibbles.
/// The trail is a vector of trie nodes on the path from root node to the node that is
/// currently being processed together with processing status - the Crumb.
///
/// The key_nibbles is a vector of nibbles from the state root node to the node that is
/// currently being processed.
///
/// The trail and the key_nibbles may have different lengths e.g. an extension trie node
/// will add only a single item to the trail but may add multiple nibbles to the key_nibbles.
pub type STMemTrieIterator<'a> = MemTrieIterator<'a, HybridArenaMemory>;

pub struct MemTrieIterator<'a, M: ArenaMemory> {
root: Option<MemTrieNodePtr<'a, M>>,
/// Tiny wrapper around `MemTries` and `Trie` to provide `GenericTrieInternalStorage` implementation.
pub struct MemTrieIteratorInner<'a> {
memtrie: &'a MemTries,
trie: &'a Trie,
trail: Vec<Crumb<'a, M>>,
key_nibbles: Vec<u8>,
}

impl<'a, M: ArenaMemory> MemTrieIterator<'a, M> {
/// Create a new iterator.
pub fn new(root: Option<MemTrieNodePtr<'a, M>>, trie: &'a Trie) -> Self {
let mut r = MemTrieIterator { root, trie, trail: Vec::new(), key_nibbles: Vec::new() };
r.descend_into_node(root);
r
impl<'a> MemTrieIteratorInner<'a> {
pub fn new(memtrie: &'a MemTries, trie: &'a Trie) -> Self {
Self { memtrie, trie }
}
}

/// Positions the iterator for a prefix seek on `key`.
///
/// Converts `key` into a `NibbleSlice` and delegates to `seek_nibble_slice` with
/// `is_prefix_seek = true`, so extension/leaf keys are matched with `starts_with`
/// rather than with `>=`. The node pointer returned by the seek is discarded.
pub fn seek_prefix<K: AsRef<[u8]>>(&mut self, key: K) {
self.seek_nibble_slice(NibbleSlice::new(key.as_ref()), true);
}

/// Returns the pointer to the last node that was descended into (`None` if there is no root).
pub(crate) fn seek_nibble_slice(
&mut self,
mut key: NibbleSlice<'_>,
is_prefix_seek: bool,
) -> Option<MemTrieNodePtr<'a, M>> {
self.trail.clear();
self.key_nibbles.clear();
// Checks if a key in an extension or leaf matches our search query.
//
// When doing prefix seek, this checks whether `key` is a prefix of
// `ext_key`. When doing regular range seek, this checks whether `key`
// is no greater than `ext_key`. If those conditions aren’t met, the
// node with `ext_key` should not match our query.
let check_ext_key = |key: &NibbleSlice, ext_key: &NibbleSlice| {
if is_prefix_seek {
ext_key.starts_with(key)
} else {
ext_key >= key
}
};

let mut ptr = self.root;
let mut prev_prefix_boundary = &mut false;
loop {
*prev_prefix_boundary = is_prefix_seek;
self.descend_into_node(ptr);
let Crumb { status, node, prefix_boundary } = self.trail.last_mut().unwrap();
prev_prefix_boundary = prefix_boundary;
match &node {
None => break,
Some(MemTrieNodeView::Leaf { extension, .. }) => {
let existing_key = NibbleSlice::from_encoded(extension).0;
if !check_ext_key(&key, &existing_key) {
self.key_nibbles.extend(existing_key.iter());
*status = CrumbStatus::Exiting;
}
break;
}
Some(MemTrieNodeView::Branch { children, .. })
| Some(MemTrieNodeView::BranchWithValue { children, .. }) => {
if key.is_empty() {
break;
}
let idx = key.at(0);
self.key_nibbles.push(idx);
*status = CrumbStatus::AtChild(idx);
if let Some(child) = children.get(idx as usize) {
ptr = Some(child);
key = key.mid(1);
} else {
*prefix_boundary = is_prefix_seek;
break;
}
}
Some(MemTrieNodeView::Extension { extension, child, .. }) => {
let existing_key = NibbleSlice::from_encoded(extension).0;
if key.starts_with(&existing_key) {
key = key.mid(existing_key.len());
ptr = Some(*child);
*status = CrumbStatus::At;
self.key_nibbles.extend(existing_key.iter());
} else {
if !check_ext_key(&key, &existing_key) {
*status = CrumbStatus::Exiting;
self.key_nibbles.extend(existing_key.iter());
}
break;
}
}
}
impl<'a> GenericTrieInternalStorage<MemTrieNodeId, FlatStateValue> for MemTrieIteratorInner<'a> {
fn get_root(&self) -> Option<MemTrieNodeId> {
let root_hash = self.trie.root;
if root_hash == CryptoHash::default() {
return None;
}
ptr
}

/// Fetches node by its ptr and adds it to the trail.
///
/// The node is stored as the last [`Crumb`] in the trail.
fn descend_into_node(&mut self, ptr: Option<MemTrieNodePtr<'a, M>>) {
let node = ptr.map(|ptr| {
let view = ptr.view();
if let Some(recorder) = &self.trie.recorder {
let raw_node_serialized =
borsh::to_vec(&view.to_raw_trie_node_with_size()).unwrap();
recorder.borrow_mut().record(&view.node_hash(), raw_node_serialized.into());
}
view
});
self.trail.push(Crumb { status: CrumbStatus::Entering, node, prefix_boundary: false });
let root_node = self.memtrie.get_root(&root_hash).unwrap();
let root_ptr = root_node.id();
Some(root_ptr)
}

fn key(&self) -> Vec<u8> {
let mut result = <Vec<u8>>::with_capacity(self.key_nibbles.len() / 2);
for i in (1..self.key_nibbles.len()).step_by(2) {
result.push(self.key_nibbles[i - 1] * 16 + self.key_nibbles[i]);
fn get_and_record_node(&self, node: MemTrieNodeId) -> Result<MemTrieNode, StorageError> {
let view = node.as_ptr(self.memtrie.arena.memory()).view();
if let Some(recorder) = &self.trie.recorder {
let raw_node_serialized = borsh::to_vec(&view.to_raw_trie_node_with_size()).unwrap();
recorder.borrow_mut().record(&view.node_hash(), raw_node_serialized.into());
}
result
let node = MemTrieNode::from_existing_node_view(view);
Ok(node)
}

/// Calculates the next step of the iteration.
fn iter_step(&mut self) -> Option<IterStep<'a, M>> {
let last = self.trail.last_mut()?;
last.increment();
Some(match (last.status, &last.node) {
(CrumbStatus::Exiting, n) => {
match n {
Some(MemTrieNodeView::Leaf { extension, .. })
| Some(MemTrieNodeView::Extension { extension, .. }) => {
let existing_key = NibbleSlice::from_encoded(extension).0;
let l = self.key_nibbles.len();
self.key_nibbles.truncate(l - existing_key.len());
}
Some(MemTrieNodeView::Branch { .. })
| Some(MemTrieNodeView::BranchWithValue { .. }) => {
self.key_nibbles.pop();
}
_ => {}
}
IterStep::PopTrail
}
(CrumbStatus::At, Some(MemTrieNodeView::BranchWithValue { value, .. })) => {
IterStep::Value(value.to_optimized_value_ref())
}
(CrumbStatus::At, Some(MemTrieNodeView::Branch { .. })) => IterStep::Continue,
(CrumbStatus::At, Some(MemTrieNodeView::Leaf { extension, value })) => {
let key = NibbleSlice::from_encoded(extension).0;
self.key_nibbles.extend(key.iter());
IterStep::Value(value.to_optimized_value_ref())
}
(CrumbStatus::At, Some(MemTrieNodeView::Extension { extension, child, .. })) => {
let key = NibbleSlice::from_encoded(extension).0;
self.key_nibbles.extend(key.iter());
IterStep::Descend(*child)
}
(CrumbStatus::AtChild(i), Some(MemTrieNodeView::Branch { children, .. }))
| (CrumbStatus::AtChild(i), Some(MemTrieNodeView::BranchWithValue { children, .. })) => {
if i == 0 {
self.key_nibbles.push(0);
}
if let Some(ref child) = children.get(i as usize) {
if i != 0 {
*self.key_nibbles.last_mut().expect("Pushed child value before") = i;
}
IterStep::Descend(*child)
} else {
IterStep::Continue
}
}
_ => panic!("Should never see Entering or AtChild without a Branch here."),
})
fn get_and_record_value(&self, value_ref: FlatStateValue) -> Result<Vec<u8>, StorageError> {
let optimized_value_ref = OptimizedValueRef::from_flat_value(value_ref);
let value = self.trie.deref_optimized(&optimized_value_ref)?;
if let Some(recorder) = &self.trie.recorder {
let value_hash = optimized_value_ref.into_value_ref().hash;
recorder.borrow_mut().record(&value_hash, value.clone().into());
};
Ok(value)
}
}

/// Outcome of a single `iter_step` call, telling `next` what to do.
#[derive(Debug)]
enum IterStep<'a, M: ArenaMemory> {
/// Nothing to do for this step; keep stepping.
Continue,
/// The current node is finished; pop the last crumb off the trail.
PopTrail,
/// Push the given child node onto the trail.
Descend(MemTrieNodePtr<'a, M>),
/// A value was reached; dereference it and yield it with the current key.
Value(OptimizedValueRef),
}

impl<'a, M: ArenaMemory> Iterator for MemTrieIterator<'a, M> {
    type Item = Result<TrieItem, StorageError>;

    /// Drives `iter_step` until it produces a value or the trail is exhausted.
    ///
    /// Returns `None` once `iter_step` has nothing left. Otherwise yields the
    /// current key with the dereferenced value, or the error from dereferencing it.
    /// A successfully read value is also recorded when the trie has a recorder.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(step) = self.iter_step() {
            // Bookkeeping steps loop around; only a `Value` step falls through.
            let value_ref = match step {
                IterStep::Continue => continue,
                IterStep::PopTrail => {
                    self.trail.pop();
                    continue;
                }
                IterStep::Descend(ptr) => {
                    self.descend_into_node(Some(ptr));
                    continue;
                }
                IterStep::Value(value_ref) => value_ref,
            };

            let value = self.trie.deref_optimized(&value_ref);
            if let (Ok(bytes), Some(recorder)) = (&value, &self.trie.recorder) {
                let value_hash = value_ref.into_value_ref().hash;
                recorder.borrow_mut().record(&value_hash, bytes.clone().into());
            }
            return Some(value.map(|bytes| (self.key(), bytes)));
        }
        None
    }
}
/// Memtrie iterator: the shared `TrieIteratorImpl` instantiated with `MemTrieNodeId` node
/// ids and `FlatStateValue` values, reading nodes through `MemTrieIteratorInner`.
pub type STMemTrieIterator<'a> =
TrieIteratorImpl<MemTrieNodeId, FlatStateValue, MemTrieIteratorInner<'a>>;
17 changes: 10 additions & 7 deletions core/store/src/trie/mem/memtrie_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@ pub type MemTrieNodeWithSize = GenericTrieNodeWithSize<MemTrieNodeId, FlatStateV

pub type UpdatedMemTrieNodeWithSize = GenericUpdatedTrieNodeWithSize<MemTrieNodeId, FlatStateValue>;

impl MemTrieNodeWithSize {
/// Converts an existing in-memory trie node into an updated one that is
/// equivalent.
impl MemTrieNode {
pub fn from_existing_node_view<'a, M: ArenaMemory>(view: MemTrieNodeView<'a, M>) -> Self {
let memory_usage = view.memory_usage();
let node = match view {
match view {
MemTrieNodeView::Leaf { extension, value } => MemTrieNode::Leaf {
extension: extension.to_vec().into_boxed_slice(),
value: value.to_flat_value(),
Expand All @@ -52,8 +49,7 @@ impl MemTrieNodeWithSize {
extension: extension.to_vec().into_boxed_slice(),
child: child.id(),
},
};
Self { node, memory_usage }
}
}

fn convert_children_to_updated<'a, M: ArenaMemory>(
Expand All @@ -69,6 +65,13 @@ impl MemTrieNodeWithSize {
}
}

impl MemTrieNodeWithSize {
    /// Builds a `MemTrieNodeWithSize` from a view of an existing in-memory trie node.
    ///
    /// The node part is produced by `MemTrieNode::from_existing_node_view`; the
    /// memory usage is taken from the view itself.
    pub fn from_existing_node_view<'a, M: ArenaMemory>(view: MemTrieNodeView<'a, M>) -> Self {
        // Read the size before the view is handed over to the node conversion.
        let memory_usage = view.memory_usage();
        let node = MemTrieNode::from_existing_node_view(view);
        Self { node, memory_usage }
    }
}

/// Allows using in-memory tries to construct the trie node changes entirely
/// (for both in-memory and on-disk updates) because it's much faster.
pub enum TrackingMode<'a> {
Expand Down
Loading

0 comments on commit 57b9b81

Please sign in to comment.