Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add filter functionality #16

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lib-sokoban"
version = "0.3.1"
version = "0.4.0"
edition = "2021"
repository = "https://github.com/jarry-xiao/sokoban"
authors = ["jarry-xiao <[email protected]>"]
Expand Down
143 changes: 143 additions & 0 deletions examples/bench_extract_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::time::Instant;

use sokoban::{NodeAllocatorMap, RedBlackTree};

#[derive(Default, PartialEq, PartialOrd, Clone, Copy, Debug, Ord, Eq)]
#[repr(C)]
/// Price-time priority key
struct Key {
price: u64,
id: u64,
}
impl Key {
fn rand() -> Self {
Self {
price: rand::random(),
id: rand::random(),
}
}
}
unsafe impl bytemuck::Pod for Key {}
unsafe impl bytemuck::Zeroable for Key {}

#[derive(Default, PartialEq, PartialOrd, Clone, Copy)]
#[repr(C)]
/// Mock limit order key
struct Entry {
lots: u64,
maker: u32,
_pad: [u8; 4],
}

impl Entry {
fn rand() -> Self {
Entry {
lots: 10,
maker: 0,
_pad: [0; 4],
}
}
fn rand_with_maker(idx: u32) -> Self {
assert!(idx > 0); // 0 is reserved
Entry {
lots: 10,
maker: idx,
_pad: [0; 4],
}
}
}

unsafe impl bytemuck::Pod for Entry {}
unsafe impl bytemuck::Zeroable for Entry {}

fn main() {
const ITERS: usize = 1000;
const WARMUP_ITERS: usize = 100;

const TARGET_MAKER: u32 = 5;

const TREE_SIZE: usize = 4096;
const REMOVE: usize = 256;

let mut total_remove_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
let keys = tree
.iter()
.filter(|(_key, entry)| entry.maker == TARGET_MAKER)
.map(|(key, _)| *key)
.collect::<Vec<_>>();
for key in keys {
tree.remove(&key);
}
if i > WARMUP_ITERS {
total_remove_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), TREE_SIZE - REMOVE);
}
println!("average id + remove: {total_remove_micros} micros");

let mut total_drain_alloc_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
drop(
tree.extract_if(
#[inline(always)]
|_k, v| v.maker == TARGET_MAKER,
)
.collect::<Vec<_>>(),
);
if i > WARMUP_ITERS {
total_drain_alloc_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), TREE_SIZE - REMOVE);
}
println!("average extract_if_alloc: {total_drain_alloc_micros} micros");

let mut total_drain_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
for _x in tree.extract_if(
#[inline(always)]
|_k, v| v.maker == TARGET_MAKER,
) {}
if i > WARMUP_ITERS {
total_drain_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), 4096 - REMOVE);
}
println!("average extract_if: {total_drain_micros} micros");
}
34 changes: 34 additions & 0 deletions examples/extract_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use sokoban::{NodeAllocatorMap, RedBlackTree};

fn main() {
{
let mut tree = RedBlackTree::<u32, u32, 8>::new();
tree.insert(0, 5); // this
tree.insert(1, 5);
tree.insert(2, 0); // this
tree.insert(3, 5);
tree.insert(4, 0); // this
tree.insert(5, 5);
tree.insert(6, 5);
tree.insert(7, 0); // this

println!("initial elements:");
for x in tree.iter() {
println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}

println!("\n Removing nodes");
for x in tree.extract_if(my_predicate) {
println!("removed node {} {}", x.0, x.1);
}

println!("\n remaining elements:");
for x in tree.iter() {
println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}
}
}

fn my_predicate(key: &u32, value: &u32) -> bool {
(*key == 0) | (*value == 0)
}
155 changes: 155 additions & 0 deletions src/red_black_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,80 @@ impl<
terminated: false,
}
}

/// Returns an `Iterator` that selectively removes and returns items
/// from the tree where the given predicate evaluates to `true`.
///
/// Recall that iterators are lazy. If the iterator is dropped before
/// iterating through the entire tree, not all items where the given
/// predicate would evaluate to `true` will be removed. Note that this
/// can be used to remove up to some number of entries from the tree.
///
/// Example:
/// ```rust
/// use sokoban::{RedBlackTree, NodeAllocatorMap};
/// let mut tree = RedBlackTree::<u32, u32, 8>::new();
///
/// // Remove if key or value is zero
/// let predicate = {
/// #[inline(always)]
/// |k: &u32, v: &u32| (*k == 0) | (*v == 0)
/// };
/// tree.insert(0, 5); // Key is zero
/// tree.insert(1, 5);
/// tree.insert(2, 0); // Value is zero
/// tree.insert(3, 5);
/// tree.insert(4, 0); // Value is zero
/// tree.insert(5, 5);
/// tree.insert(6, 5);
/// tree.insert(7, 0); // Value is zero
///
/// for x in tree.extract_if(predicate) {
/// println!("removed node {} {}", x.0, x.1);
/// }
///
/// // After removing all pairs with a zero key or value
/// // these should be the remaining elements
/// let remaining = [
/// (&1, &5),
/// (&3, &5),
/// (&5, &5),
/// (&6, &5),
/// ];
/// assert!(Iterator::eq(tree.iter(), remaining.into_iter()));
///
/// // Remove if value is 5 (all elements), removing up to 2 entries
/// let predicate = {
/// #[inline(always)]
/// |_k: &u32, v: &u32| (*v == 5)
/// };
/// for x in tree.extract_if(predicate).take(2) {
/// println!("removed node {} {}", x.0, x.1);
/// }
/// let remaining = [
/// (&5, &5),
/// (&6, &5),
/// ];
/// assert!(Iterator::eq(tree.iter(), remaining.into_iter()));
/// ```
pub fn extract_if<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>(
&'a mut self,
predicate: P,
) -> RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE> {
let node = self.root;
RedBlackTreeExtractIf {
tree: self,
fwd_stack: vec![],
fwd_ptr: node,
fwd_node: None,
_rev_stack: vec![],
_rev_ptr: node,
rev_node: None,
terminated: false,
remove: vec![],
jarry-xiao marked this conversation as resolved.
Show resolved Hide resolved
predicate,
}
}
}

impl<
Expand Down Expand Up @@ -933,6 +1007,87 @@ impl<
}
}

pub struct RedBlackTreeExtractIf<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> {
tree: &'a mut RedBlackTree<K, V, MAX_SIZE>,
fwd_stack: Vec<u32>,
fwd_ptr: u32,
fwd_node: Option<u32>,
_rev_stack: Vec<u32>,
_rev_ptr: u32,
rev_node: Option<u32>,
terminated: bool,
/// Keeps addr's of nodes with predicate = true, to remove upon dropping.
/// It is possible to instead drop during iteration and update fwd_ptr & fwd_stack
remove: Vec<u32>,
predicate: P,
}

impl<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> Drop for RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE>
{
fn drop(&mut self) {
Copy link
Collaborator

@jarry-xiao jarry-xiao Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gist of the speed up here is that the traversal traversal to remove nodes is O(K) instead of O(K log N) right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually still O(K log N) but you save a log N on the find

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially, with remove you have to traverse the tree from the root to the entry to re-find the node addr of the entry you just found.

If you iterate through the elements and record the addr instead of the keys, you don't have to do that traversal to remove the entry. That's the gist

for node_index in core::mem::take(&mut self.remove) {
self.tree._remove_tree_node(node_index);
}
}
}

impl<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> Iterator for RedBlackTreeExtractIf<'a, K, V, P, MAX_SIZE>
{
type Item = (K, V);

fn next(&mut self) -> Option<Self::Item> {
while !self.terminated && (!self.fwd_stack.is_empty() || self.fwd_ptr != SENTINEL) {
if self.fwd_ptr != SENTINEL {
self.fwd_stack.push(self.fwd_ptr);
self.fwd_ptr = self.tree.get_left(self.fwd_ptr);
} else {
let current_node = self.fwd_stack.pop();
if current_node == self.rev_node {
self.terminated = true;
return None;
}
self.fwd_node = current_node;
let ptr = self.fwd_node.unwrap();

// Get node, check predicate.
let node = self.tree.get_node(ptr);

// If predicate, remove and return
if (self.predicate)(&node.key, &node.value) {
let (key, value) = (node.key, node.value);
self.fwd_ptr = self.tree.get_right(ptr);
self.remove.push(ptr);

// Remove and return
return Some((key, value));
} else {
self.fwd_ptr = self.tree.get_right(ptr);
}
}
}

None
}
}

impl<
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
Expand Down