From 31e7c618a68e06431cdcd5a526e2ea4c0cd0f639 Mon Sep 17 00:00:00 2001 From: Alex Su <7680266+alexytsu@users.noreply.github.com> Date: Wed, 15 Feb 2023 17:21:13 +1100 Subject: [PATCH] Ranged HAMT iteration specified via keys (#1665) --- .gitignore | 1 + ipld/hamt/src/hamt.rs | 69 ++++++++++++++ ipld/hamt/src/node.rs | 103 ++++++++++++++++++++ ipld/hamt/tests/hamt_tests.rs | 173 ++++++++++++++++++++++++++++++++++ 4 files changed, 346 insertions(+) diff --git a/.gitignore b/.gitignore index 53d59825b..0c0635787 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ target !testing/integration/tests/assets/* testing/conformance/traces .idea/ +.vscode/ lcov.info diff --git a/ipld/hamt/src/hamt.rs b/ipld/hamt/src/hamt.rs index e07ad83bf..fbe3332d1 100644 --- a/ipld/hamt/src/hamt.rs +++ b/ipld/hamt/src/hamt.rs @@ -13,6 +13,7 @@ use multihash::Code; use serde::de::DeserializeOwned; use serde::{Serialize, Serializer}; +use crate::hash_bits::HashBits; use crate::node::Node; use crate::{Config, Error, Hash, HashAlgorithm, Sha256}; @@ -362,6 +363,74 @@ where self.root.for_each(self.store.borrow(), &mut f) } + /// Iterates over each KV in the Hamt and runs a function on the values. If starting key is + /// provided, iteration will start from that key. If max is provided, iteration will stop after + /// max number of items have been traversed. The number of items that were traversed is + /// returned. If there are more items in the Hamt after max items have been traversed, the key + /// of the next item will be returned. + /// + /// This function will constrain all values to be of the same type + /// + /// # Examples + /// + /// ``` + /// use fvm_ipld_hamt::Hamt; + /// + /// let store = fvm_ipld_blockstore::MemoryBlockstore::default(); + /// + /// let mut map: Hamt<_, _, u64> = Hamt::new(store); + /// map.set(1, 1).unwrap(); + /// map.set(2, 2).unwrap(); + /// map.set(3, 3).unwrap(); + /// map.set(4, 4).unwrap(); + /// + /// let mut numbers = vec![]; + /// + /// map.for_each_ranged(None, None, |_, v: &u64| { + /// numbers.push(*v); + /// Ok(()) + /// }).unwrap(); + /// + /// let mut subset = vec![]; + /// + /// let (_, next_key) = map.for_each_ranged(Some(&numbers[0]), Some(2), |_, v: &u64| { + /// subset.push(*v); + /// Ok(()) + /// }).unwrap(); + /// + /// assert_eq!(subset, numbers[..2]); + /// assert_eq!(next_key.unwrap(), numbers[2]); + /// ``` + #[inline] + pub fn for_each_ranged( + &self, + starting_key: Option<&Q>, + max: Option, + mut f: F, + ) -> Result<(usize, Option), Error> + where + K: Borrow + Clone, + Q: Eq + Hash, + V: DeserializeOwned, + F: FnMut(&K, &V) -> anyhow::Result<()>, + { + match starting_key { + Some(key) => { + let hash = H::hash(key); + self.root.for_each_ranged( + self.store.borrow(), + &self.conf, + Some((HashBits::new(&hash), key)), + max, + &mut f, + ) + } + None => self + .root + .for_each_ranged(self.store.borrow(), &self.conf, None, max, &mut f), + } + } + /// Consumes this HAMT and returns the Blockstore it owns. pub fn into_store(self) -> BS { self.store diff --git a/ipld/hamt/src/node.rs b/ipld/hamt/src/node.rs index 66e1683e9..81bad7d67 100644 --- a/ipld/hamt/src/node.rs +++ b/ipld/hamt/src/node.rs @@ -174,6 +174,109 @@ where Ok(()) } + pub(crate) fn for_each_ranged( + &self, + store: &S, + conf: &Config, + mut starting_cursor: Option<(HashBits, &Q)>, + limit: Option, + f: &mut F, + ) -> Result<(usize, Option), Error> + where + K: Borrow + Clone, + Q: Eq + Hash, + F: FnMut(&K, &V) -> anyhow::Result<()>, + S: Blockstore, + { + // determine which subtree the starting_cursor is in + let cindex = match starting_cursor { + Some((ref mut bits, _)) => { + let idx = bits.next(conf.bit_width)?; + self.index_for_bit_pos(idx) + } + None => 0, + }; + + let mut traversed_count = 0; + + // skip exploration of subtrees that are before the subtree which contains the cursor + for p in &self.pointers[cindex..] { + match p { + Pointer::Link { cid, cache } => { + if let Some(cached_node) = cache.get() { + let (traversed, key) = cached_node.for_each_ranged( + store, + conf, + starting_cursor.take(), + limit.map(|l| l.checked_sub(traversed_count).unwrap()), + f, + )?; + traversed_count += traversed; + if limit.map_or(false, |l| traversed_count >= l) && key.is_some() { + return Ok((traversed_count, key)); + } + } else { + let node = if let Some(node) = store.get_cbor(cid)? { + node + } else { + #[cfg(not(feature = "ignore-dead-links"))] + return Err(Error::CidNotFound(cid.to_string())); + + #[cfg(feature = "ignore-dead-links")] + continue; + }; + + // Ignore error intentionally, the cache value will always be the same + let cache_node = cache.get_or_init(|| node); + let (traversed, key) = cache_node.for_each_ranged( + store, + conf, + starting_cursor.take(), + limit.map(|l| l.checked_sub(traversed_count).unwrap()), + f, + )?; + traversed_count += traversed; + if limit.map_or(false, |l| traversed_count >= l) && key.is_some() { + return Ok((traversed_count, key)); + } + } + } + Pointer::Dirty(node) => { + let (traversed, key) = node.for_each_ranged( + store, + conf, + starting_cursor.take(), + limit.map(|l| l.checked_sub(traversed_count).unwrap()), + f, + )?; + traversed_count += traversed; + if limit.map_or(false, |l| traversed_count >= l) && key.is_some() { + return Ok((traversed_count, key)); + } + } + Pointer::Values(kvs) => { + for kv in kvs { + if limit.map_or(false, |l| traversed_count == l) { + // we have already found all requested items, return the key of the next item + return Ok((traversed_count, Some(kv.0.clone()))); + } else if starting_cursor.map_or(false, |(_, key)| key.eq(kv.0.borrow())) { + // mark that we have arrived at the starting cursor + starting_cursor = None + } + + if starting_cursor.is_none() { + // have already passed the start cursor + f(&kv.0, kv.1.borrow())?; + traversed_count += 1; + } + } + } + } + } + + Ok((traversed_count, None)) + } + /// Search for a key. fn search( &self, diff --git a/ipld/hamt/tests/hamt_tests.rs b/ipld/hamt/tests/hamt_tests.rs index 1d213006f..9c7312d37 100644 --- a/ipld/hamt/tests/hamt_tests.rs +++ b/ipld/hamt/tests/hamt_tests.rs @@ -386,6 +386,168 @@ fn for_each(factory: HamtFactory, stats: Option, mut cids: CidChecker) } } +fn for_each_ranged(factory: HamtFactory, stats: Option, mut cids: CidChecker) { + let mem = MemoryBlockstore::default(); + let store = TrackingBlockstore::new(&mem); + + let mut hamt: Hamt<_, usize> = factory.new_with_bit_width(&store, 5); + + const RANGE: usize = 200; + for i in 0..RANGE { + hamt.set(tstring(i), i).unwrap(); + } + + // collect all KV paris by iterating through the entire hamt + let mut kvs = Vec::new(); + hamt.for_each(|k, v| { + assert_eq!(k, &tstring(v)); + kvs.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(); + + // Iterate through the array, requesting pages of different sizes + for page_size in 0..RANGE { + let mut kvs_variable_page = Vec::new(); + let (num_traversed, next_key) = hamt + .for_each_ranged::(None, Some(page_size), |k, v| { + kvs_variable_page.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(); + + assert_eq!(num_traversed, page_size); + assert_eq!(kvs_variable_page.len(), num_traversed); + assert_eq!(next_key.unwrap(), kvs[page_size].0); + + // Items iterated over should match the ordering of for_each + assert_eq!(kvs_variable_page, kvs[..page_size]); + } + + // Iterate through the array, requesting more items than are remaining + let (num_traversed, next_key) = hamt + .for_each_ranged::(None, Some(RANGE + 10), |_k, _v| Ok(())) + .unwrap(); + assert_eq!(num_traversed, RANGE); + assert_eq!(next_key, None); + + // Iterate through it again starting at a certain key + for start_at in 0..RANGE as usize { + let mut kvs_variable_start = Vec::new(); + let (num_traversed, next_key) = hamt + .for_each_ranged(Some(&kvs[start_at].0), None, |k, v| { + assert_eq!(k, &tstring(v)); + kvs_variable_start.push((k.clone(), *v)); + + Ok(()) + }) + .unwrap(); + + // No limit specified, iteration should be exhaustive + assert_eq!(next_key, None); + assert_eq!(num_traversed, kvs_variable_start.len()); + assert_eq!(kvs_variable_start.len(), kvs.len() - start_at,); + + // Items iterated over should match the ordering of for_each + assert_eq!(kvs_variable_start, kvs[start_at..]); + } + + // Chain paginated requests to iterate over entire HAMT + { + let mut kvs_paginated_requests = Vec::new(); + let mut iterations = 0; + let mut cursor: Option = None; + + // Request all items in pages of 20 items each + const PAGE_SIZE: usize = 20; + loop { + let (page_size, next) = match cursor { + Some(ref start) => hamt + .for_each_ranged::(Some(start), Some(PAGE_SIZE), |k, v| { + kvs_paginated_requests.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(), + None => hamt + .for_each_ranged::(None, Some(PAGE_SIZE), |k, v| { + kvs_paginated_requests.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(), + }; + iterations += 1; + assert_eq!(page_size, PAGE_SIZE); + assert_eq!(kvs_paginated_requests.len(), iterations * PAGE_SIZE); + + if next.is_none() { + break; + } else { + assert_eq!(next.clone().unwrap(), kvs[(iterations * PAGE_SIZE)].0); + cursor = next; + } + } + + // should have retrieved all key value pairs in the same order + assert_eq!(kvs_paginated_requests.len(), kvs.len(), "{}", iterations); + assert_eq!(kvs_paginated_requests, kvs); + // should have used the expected number of iterations + assert_eq!(iterations, RANGE / PAGE_SIZE); + } + + let c = hamt.flush().unwrap(); + cids.check_next(c); + + // Chain paginated requests over a HAMT with committed nodes + let mut hamt: Hamt<_, usize> = factory.load_with_bit_width(&c, &store, 5).unwrap(); + { + let mut kvs_paginated_requests = Vec::new(); + let mut iterations = 0; + let mut cursor: Option = None; + + // Request all items in pages of 20 items each + const PAGE_SIZE: usize = 20; + loop { + let (page_size, next) = match cursor { + Some(ref start) => hamt + .for_each_ranged::(Some(start), Some(PAGE_SIZE), |k, v| { + kvs_paginated_requests.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(), + None => hamt + .for_each_ranged::(None, Some(PAGE_SIZE), |k, v| { + kvs_paginated_requests.push((k.clone(), *v)); + Ok(()) + }) + .unwrap(), + }; + iterations += 1; + assert_eq!(page_size, PAGE_SIZE); + assert_eq!(kvs_paginated_requests.len(), iterations * PAGE_SIZE); + + if next.is_none() { + break; + } else { + assert_eq!(next.clone().unwrap(), kvs[(iterations * PAGE_SIZE)].0); + cursor = next; + } + } + + // should have retrieved all key value pairs in the same order + assert_eq!(kvs_paginated_requests.len(), kvs.len(), "{}", iterations); + assert_eq!(kvs_paginated_requests, kvs); + // should have used the expected number of iterations + assert_eq!(iterations, RANGE / PAGE_SIZE); + } + + let c = hamt.flush().unwrap(); + cids.check_next(c); + + if let Some(stats) = stats { + assert_eq!(*store.stats.borrow(), stats); + } +} + #[cfg(feature = "identity")] fn add_and_remove_keys( bit_width: u32, @@ -823,6 +985,17 @@ mod test_default { super::for_each(HamtFactory::default(), Some(stats), cids); } + #[test] + fn for_each_ranged() { + #[rustfmt::skip] + let stats = BSStats {r: 30, w: 30, br: 2895, bw: 2895}; + let cids = CidChecker::new(vec![ + "bafy2bzacedy4ypl2vedhdqep3llnwko6vrtfiys5flciz2f3c55pl4whlhlqm", + "bafy2bzacedy4ypl2vedhdqep3llnwko6vrtfiys5flciz2f3c55pl4whlhlqm", + ]); + super::for_each_ranged(HamtFactory::default(), Some(stats), cids); + } + #[test] fn clean_child_ordering() { #[rustfmt::skip]