From c1000501c25e63f1ddb2b71f2f66225f2ec72111 Mon Sep 17 00:00:00 2001 From: cavemanloverboy Date: Sat, 16 Dec 2023 20:22:51 -0800 Subject: [PATCH 1/4] add filter functionality --- examples/bench_filter.rs | 143 +++++++++++++++++++++++++++++++++++++++ examples/filter.rs | 62 +++++++++++++++++ src/red_black_tree.rs | 113 +++++++++++++++++++++++++++++++ 3 files changed, 318 insertions(+) create mode 100644 examples/bench_filter.rs create mode 100644 examples/filter.rs diff --git a/examples/bench_filter.rs b/examples/bench_filter.rs new file mode 100644 index 0000000..b11b3ca --- /dev/null +++ b/examples/bench_filter.rs @@ -0,0 +1,143 @@ +use std::time::Instant; + +use sokoban::{NodeAllocatorMap, RedBlackTree}; + +#[derive(Default, PartialEq, PartialOrd, Clone, Copy, Debug, Ord, Eq)] +#[repr(C)] +/// Price-time priority key +struct Key { + price: u64, + id: u64, +} +impl Key { + fn rand() -> Self { + Self { + price: rand::random(), + id: rand::random(), + } + } +} +unsafe impl bytemuck::Pod for Key {} +unsafe impl bytemuck::Zeroable for Key {} + +#[derive(Default, PartialEq, PartialOrd, Clone, Copy)] +#[repr(C)] +/// Mock limit order key +struct Entry { + lots: u64, + maker: u32, + _pad: [u8; 4], +} + +impl Entry { + fn rand() -> Self { + Entry { + lots: 10, + maker: 0, + _pad: [0; 4], + } + } + fn rand_with_maker(idx: u32) -> Self { + assert!(idx > 0); // 0 is reserved + Entry { + lots: 10, + maker: idx, + _pad: [0; 4], + } + } +} + +unsafe impl bytemuck::Pod for Entry {} +unsafe impl bytemuck::Zeroable for Entry {} + +fn main() { + const ITERS: usize = 1000; + const WARMUP_ITERS: usize = 100; + + const TARGET_MAKER: u32 = 5; + + const TREE_SIZE: usize = 4096; + const REMOVE: usize = 256; + + let mut total_remove_micros = 0; + for i in 0..ITERS + WARMUP_ITERS { + // Setup + let mut tree = RedBlackTree::::new(); + for i in 0..TREE_SIZE { + if i < REMOVE { + tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER)); + } else { + tree.insert(Key::rand(), Entry::rand()); + } + } + + // Start filter + let timer = Instant::now(); + let keys = tree + .iter() + .filter(|(_key, entry)| entry.maker == TARGET_MAKER) + .map(|(key, _)| *key) + .collect::>(); + for key in keys { + tree.remove(&key); + } + if i > WARMUP_ITERS { + total_remove_micros += timer.elapsed().as_micros(); + } + assert_eq!(tree.len(), TREE_SIZE - REMOVE); + } + println!("average id + remove: {total_remove_micros} micros"); + + let mut total_drain_alloc_micros = 0; + for i in 0..ITERS + WARMUP_ITERS { + // Setup + let mut tree = RedBlackTree::::new(); + for i in 0..TREE_SIZE { + if i < REMOVE { + tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER)); + } else { + tree.insert(Key::rand(), Entry::rand()); + } + } + + // Start filter + let timer = Instant::now(); + drop( + tree.drain_filter( + #[inline(always)] + |_k, v| v.maker == TARGET_MAKER, + ) + .collect::>(), + ); + if i > WARMUP_ITERS { + total_drain_alloc_micros += timer.elapsed().as_micros(); + } + assert_eq!(tree.len(), TREE_SIZE - REMOVE); + } + println!("average drain_alloc: {total_drain_alloc_micros} micros"); + + let mut total_drain_micros = 0; + for i in 0..ITERS + WARMUP_ITERS { + // Setup + let mut tree = RedBlackTree::::new(); + for i in 0..TREE_SIZE { + if i < REMOVE { + tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER)); + } else { + tree.insert(Key::rand(), Entry::rand()); + } + } + + // Start filter + let timer = Instant::now(); + for _x in tree.drain_filter( + #[inline(always)] + |_k, v| v.maker == TARGET_MAKER, + ) {} + if i > WARMUP_ITERS { + total_drain_micros += timer.elapsed().as_micros(); + } + assert_eq!(tree.len(), 4096 - REMOVE); + } + println!("average drain: {total_drain_micros} micros"); +} diff --git a/examples/filter.rs b/examples/filter.rs new file mode 100644 index 0000000..200c5be --- /dev/null +++ b/examples/filter.rs @@ -0,0 +1,62 @@ +use sokoban::{NodeAllocatorMap, RedBlackTree}; + +fn main() { + { + let mut tree = RedBlackTree::::new(); + tree.insert(0, 5); // this + tree.insert(1, 5); + tree.insert(2, 0); // this + tree.insert(3, 5); + tree.insert(4, 0); // this + tree.insert(5, 5); + tree.insert(6, 5); + tree.insert(7, 0); // this + + println!("initial elements:"); + for x in tree.iter() { + println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + + println!("\n Removing nodes"); + for x in tree.drain_filter(my_predicate) { + println!("removed node {} {}", x.0, x.1); + } + + println!("\n remaining elements:"); + for x in tree.iter() { + println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + } + + // Identical, but uses filter which allocates a vector for the key value pairs + { + let mut tree = RedBlackTree::::new(); + tree.insert(0, 5); // this + tree.insert(1, 5); + tree.insert(2, 0); // this + tree.insert(3, 5); + tree.insert(4, 0); // this + tree.insert(5, 5); + tree.insert(6, 5); + tree.insert(7, 0); // this + + println!("initial elements:"); + for x in tree.iter() { + println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + + println!("\n Removing nodes"); + for x in tree.filter(my_predicate) { + println!("removed node {} {}", x.0, x.1); + } + + println!("\n remaining elements:"); + for x in tree.iter() { + println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + } +} + +fn my_predicate(key: &u32, value: &u32) -> bool { + (*key == 0) | (*value == 0) +} diff --git a/src/red_black_tree.rs b/src/red_black_tree.rs index a3b0116..c148cbc 100644 --- a/src/red_black_tree.rs +++ b/src/red_black_tree.rs @@ -735,6 +735,32 @@ impl< terminated: false, } } + + pub fn filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>( + &'a mut self, + predicate: P, + ) -> Vec<(K, V)> { + self.drain_filter(predicate).collect::>() + } + + pub fn drain_filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>( + &'a mut self, + predicate: P, + ) -> RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> { + let node = self.root; + RedBlackTreeDrainFilter { + tree: self, + fwd_stack: vec![], + fwd_ptr: node, + fwd_node: None, + _rev_stack: vec![], + _rev_ptr: node, + rev_node: None, + terminated: false, + remove: vec![], + predicate, + } + } } impl< @@ -933,6 +959,93 @@ impl< } } +pub struct RedBlackTreeDrainFilter< + 'a, + K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable, + V: Default + Copy + Clone + Pod + Zeroable, + P: for<'p> Fn(&'p K, &'p V) -> bool, + const MAX_SIZE: usize, +> { + tree: &'a mut RedBlackTree, + fwd_stack: Vec, + fwd_ptr: u32, + fwd_node: Option, + _rev_stack: Vec, + _rev_ptr: u32, + rev_node: Option, + terminated: bool, + /// Keeps addr's of nodes with predicate = true, to remove upon dropping. + /// It is possible to instead drop during iteration and update fwd_ptr & fwd_stack + remove: Vec, + predicate: P, +} + +impl< + 'a, + K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable, + V: Default + Copy + Clone + Pod + Zeroable, + P: for<'p> Fn(&'p K, &'p V) -> bool, + const MAX_SIZE: usize, + > Drop for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> +{ + fn drop(&mut self) { + for node_index in core::mem::take(&mut self.remove) { + self.tree._remove_tree_node(node_index); + } + } +} + +impl< + 'a, + K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable, + V: Default + Copy + Clone + Pod + Zeroable, + P: for<'p> Fn(&'p K, &'p V) -> bool, + const MAX_SIZE: usize, + > Iterator for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> +{ + type Item = (K, V); + + fn next(&mut self) -> Option { + while !self.terminated && (!self.fwd_stack.is_empty() || self.fwd_ptr != SENTINEL) { + if self.fwd_ptr != SENTINEL { + self.fwd_stack.push(self.fwd_ptr); + self.fwd_ptr = self.tree.get_left(self.fwd_ptr); + } else { + let current_node = self.fwd_stack.pop(); + if current_node == self.rev_node { + self.terminated = true; + return None; + } + self.fwd_node = current_node; + let ptr = self.fwd_node.unwrap(); + + // Get node, check predicate. + // If predicate, remove and return + let node = self + .tree + .allocator + .nodes + .get((ptr - 1) as usize) + .unwrap() + .get_value(); + + if (self.predicate)(&node.key, &node.value) { + let (key, value) = (node.key, node.value); + self.fwd_ptr = self.tree.get_right(ptr); + self.remove.push(ptr); + + // Remove and return + return Some((key, value)); + } else { + self.fwd_ptr = self.tree.get_right(ptr); + } + } + } + + None + } +} + impl< K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable, V: Default + Copy + Clone + Pod + Zeroable, From 8b84f04452fd907484bb6e625fce7476c48fd65e Mon Sep 17 00:00:00 2001 From: cavemanloverboy Date: Mon, 18 Dec 2023 16:45:14 -0800 Subject: [PATCH 2/4] update method name --- .../{bench_filter.rs => bench_extract_if.rs} | 8 +- examples/extract_if.rs | 34 +++++++++ examples/filter.rs | 62 ---------------- src/red_black_tree.rs | 74 +++++++++++++++---- 4 files changed, 99 insertions(+), 79 deletions(-) rename examples/{bench_filter.rs => bench_extract_if.rs} (94%) create mode 100644 examples/extract_if.rs delete mode 100644 examples/filter.rs diff --git a/examples/bench_filter.rs b/examples/bench_extract_if.rs similarity index 94% rename from examples/bench_filter.rs rename to examples/bench_extract_if.rs index b11b3ca..ce16135 100644 --- a/examples/bench_filter.rs +++ b/examples/bench_extract_if.rs @@ -103,7 +103,7 @@ fn main() { // Start filter let timer = Instant::now(); drop( - tree.drain_filter( + tree.extract_if( #[inline(always)] |_k, v| v.maker == TARGET_MAKER, ) @@ -114,7 +114,7 @@ fn main() { } assert_eq!(tree.len(), TREE_SIZE - REMOVE); } - println!("average drain_alloc: {total_drain_alloc_micros} micros"); + println!("average extract_if_alloc: {total_drain_alloc_micros} micros"); let mut total_drain_micros = 0; for i in 0..ITERS + WARMUP_ITERS { @@ -130,7 +130,7 @@ fn main() { // Start filter let timer = Instant::now(); - for _x in tree.drain_filter( + for _x in tree.extract_if( #[inline(always)] |_k, v| v.maker == TARGET_MAKER, ) {} @@ -139,5 +139,5 @@ fn main() { } assert_eq!(tree.len(), 4096 - REMOVE); } - println!("average drain: {total_drain_micros} micros"); + println!("average extract_if: {total_drain_micros} micros"); } diff --git a/examples/extract_if.rs b/examples/extract_if.rs new file mode 100644 index 0000000..553dfeb --- /dev/null +++ b/examples/extract_if.rs @@ -0,0 +1,34 @@ +use sokoban::{NodeAllocatorMap, RedBlackTree}; + +fn main() { + { + let mut tree = RedBlackTree::::new(); + tree.insert(0, 5); // this + tree.insert(1, 5); + tree.insert(2, 0); // this + tree.insert(3, 5); + tree.insert(4, 0); // this + tree.insert(5, 5); + tree.insert(6, 5); + tree.insert(7, 0); // this + + println!("initial elements:"); + for x in tree.iter() { + println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + + println!("\n Removing nodes"); + for x in tree.extract_if(my_predicate) { + println!("removed node {} {}", x.0, x.1); + } + + println!("\n remaining elements:"); + for x in tree.iter() { + println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); + } + } +} + +fn my_predicate(key: &u32, value: &u32) -> bool { + (*key == 0) | (*value == 0) +} diff --git a/examples/filter.rs b/examples/filter.rs deleted file mode 100644 index 200c5be..0000000 --- a/examples/filter.rs +++ /dev/null @@ -1,62 +0,0 @@ -use sokoban::{NodeAllocatorMap, RedBlackTree}; - -fn main() { - { - let mut tree = RedBlackTree::::new(); - tree.insert(0, 5); // this - tree.insert(1, 5); - tree.insert(2, 0); // this - tree.insert(3, 5); - tree.insert(4, 0); // this - tree.insert(5, 5); - tree.insert(6, 5); - tree.insert(7, 0); // this - - println!("initial elements:"); - for x in tree.iter() { - println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); - } - - println!("\n Removing nodes"); - for x in tree.drain_filter(my_predicate) { - println!("removed node {} {}", x.0, x.1); - } - - println!("\n remaining elements:"); - for x in tree.iter() { - println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); - } - } - - // Identical, but uses filter which allocates a vector for the key value pairs - { - let mut tree = RedBlackTree::::new(); - tree.insert(0, 5); // this - tree.insert(1, 5); - tree.insert(2, 0); // this - tree.insert(3, 5); - tree.insert(4, 0); // this - tree.insert(5, 5); - tree.insert(6, 5); - tree.insert(7, 0); // this - - println!("initial elements:"); - for x in tree.iter() { - println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); - } - - println!("\n Removing nodes"); - for x in tree.filter(my_predicate) { - println!("removed node {} {}", x.0, x.1); - } - - println!("\n remaining elements:"); - for x in tree.iter() { - println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1); - } - } -} - -fn my_predicate(key: &u32, value: &u32) -> bool { - (*key == 0) | (*value == 0) -} diff --git a/src/red_black_tree.rs b/src/red_black_tree.rs index c148cbc..dddeb2f 100644 --- a/src/red_black_tree.rs +++ b/src/red_black_tree.rs @@ -736,19 +736,67 @@ impl< } } - pub fn filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>( + /// Returns an `Iterator` that selectively removes and returns items + /// from the tree where the given predicate evaluates to `true`. + /// + /// Recall that iterators are lazy. If the iterator is dropped before + /// iterating through the entire tree, not all items where the given + /// predicate would evaluate to `true` will be removed. Note that this + /// can be used to remove up to some number of entries from the tree. + /// + /// Example: + /// ```rust + /// use sokoban::{RedBlackTree, NodeAllocatorMap}; + /// let mut tree = RedBlackTree::::new(); + /// + /// // Remove if key or value is zero + /// let predicate = { + /// #[inline(always)] + /// |k: &u32, v: &u32| (*k == 0) | (*v == 0) + /// }; + /// tree.insert(0, 5); // Key is zero + /// tree.insert(1, 5); + /// tree.insert(2, 0); // Value is zero + /// tree.insert(3, 5); + /// tree.insert(4, 0); // Value is zero + /// tree.insert(5, 5); + /// tree.insert(6, 5); + /// tree.insert(7, 0); // Value is zero + /// + /// for x in tree.extract_if(predicate) { + /// println!("removed node {} {}", x.0, x.1); + /// } + /// + /// // After removing all pairs with a zero key or value + /// // these should be the remaining elements + /// let remaining = [ + /// (&1, &5), + /// (&3, &5), + /// (&5, &5), + /// (&6, &5), + /// ]; + /// assert!(Iterator::eq(tree.iter(), remaining.into_iter())); + /// + /// // Remove if value is 5 (all elements), removing up to 2 entries + /// let predicate = { + /// #[inline(always)] + /// |_k: &u32, v: &u32| (*v == 5) + /// }; + /// for x in tree.extract_if(predicate).take(2) { + /// println!("removed node {} {}", x.0, x.1); + /// } + /// let remaining = [ + /// (&5, &5), + /// (&6, &5), + /// ]; + /// assert!(Iterator::eq(tree.iter(), remaining.into_iter())); + /// ``` + pub fn extract_if<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>( &'a mut self, predicate: P, - ) -> Vec<(K, V)> { - self.drain_filter(predicate).collect::>() - } - - pub fn drain_filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>( - &'a mut self, - predicate: P, - ) -> RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> { + ) -> RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE> { let node = self.root; - RedBlackTreeDrainFilter { + RedBlackTreeExtractIf { tree: self, fwd_stack: vec![], fwd_ptr: node, @@ -959,7 +1007,7 @@ impl< } } -pub struct RedBlackTreeDrainFilter< +pub struct RedBlackTreeExtractIf< 'a, K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable, V: Default + Copy + Clone + Pod + Zeroable, @@ -986,7 +1034,7 @@ impl< V: Default + Copy + Clone + Pod + Zeroable, P: for<'p> Fn(&'p K, &'p V) -> bool, const MAX_SIZE: usize, - > Drop for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> + > Drop for RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE> { fn drop(&mut self) { for node_index in core::mem::take(&mut self.remove) { @@ -1001,7 +1049,7 @@ impl< V: Default + Copy + Clone + Pod + Zeroable, P: for<'p> Fn(&'p K, &'p V) -> bool, const MAX_SIZE: usize, - > Iterator for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> + > Iterator for RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE> { type Item = (K, V); From 20c0d29718c7f6451111088a4b3851194a5abc90 Mon Sep 17 00:00:00 2001 From: cavemanloverboy Date: Mon, 18 Dec 2023 16:45:20 -0800 Subject: [PATCH 3/4] bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 2320eff..14d85f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lib-sokoban" -version = "0.3.1" +version = "0.4.0" edition = "2021" repository = "https://github.com/jarry-xiao/sokoban" authors = ["jarry-xiao "] From b7e9959189ecb0a84a87b6bf5cca88cb411492d3 Mon Sep 17 00:00:00 2001 From: cavemanloverboy Date: Mon, 18 Dec 2023 16:48:30 -0800 Subject: [PATCH 4/4] use existing get_node method --- src/red_black_tree.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/red_black_tree.rs b/src/red_black_tree.rs index dddeb2f..e2f75a6 100644 --- a/src/red_black_tree.rs +++ b/src/red_black_tree.rs @@ -1068,15 +1068,9 @@ impl< let ptr = self.fwd_node.unwrap(); // Get node, check predicate. - // If predicate, remove and return - let node = self - .tree - .allocator - .nodes - .get((ptr - 1) as usize) - .unwrap() - .get_value(); + let node = self.tree.get_node(ptr); + // If predicate, remove and return if (self.predicate)(&node.key, &node.value) { let (key, value) = (node.key, node.value); self.fwd_ptr = self.tree.get_right(ptr);