Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add filter functionality #16

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions examples/bench_filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use std::time::Instant;

use sokoban::{NodeAllocatorMap, RedBlackTree};

#[derive(Default, PartialEq, PartialOrd, Clone, Copy, Debug, Ord, Eq)]
#[repr(C)]
/// Price-time priority key
struct Key {
price: u64,
id: u64,
}
impl Key {
fn rand() -> Self {
Self {
price: rand::random(),
id: rand::random(),
}
}
}
unsafe impl bytemuck::Pod for Key {}
unsafe impl bytemuck::Zeroable for Key {}

#[derive(Default, PartialEq, PartialOrd, Clone, Copy)]
#[repr(C)]
/// Mock limit order key
struct Entry {
lots: u64,
maker: u32,
_pad: [u8; 4],
}

impl Entry {
fn rand() -> Self {
Entry {
lots: 10,
maker: 0,
_pad: [0; 4],
}
}
fn rand_with_maker(idx: u32) -> Self {
assert!(idx > 0); // 0 is reserved
Entry {
lots: 10,
maker: idx,
_pad: [0; 4],
}
}
}

unsafe impl bytemuck::Pod for Entry {}
unsafe impl bytemuck::Zeroable for Entry {}

fn main() {
const ITERS: usize = 1000;
const WARMUP_ITERS: usize = 100;

const TARGET_MAKER: u32 = 5;

const TREE_SIZE: usize = 4096;
const REMOVE: usize = 256;

let mut total_remove_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
let keys = tree
.iter()
.filter(|(_key, entry)| entry.maker == TARGET_MAKER)
.map(|(key, _)| *key)
.collect::<Vec<_>>();
for key in keys {
tree.remove(&key);
}
if i > WARMUP_ITERS {
total_remove_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), TREE_SIZE - REMOVE);
}
println!("average id + remove: {total_remove_micros} micros");

let mut total_drain_alloc_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
drop(
tree.drain_filter(
#[inline(always)]
|_k, v| v.maker == TARGET_MAKER,
)
.collect::<Vec<_>>(),
);
if i > WARMUP_ITERS {
total_drain_alloc_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), TREE_SIZE - REMOVE);
}
println!("average drain_alloc: {total_drain_alloc_micros} micros");

let mut total_drain_micros = 0;
for i in 0..ITERS + WARMUP_ITERS {
// Setup
let mut tree = RedBlackTree::<Key, Entry, TREE_SIZE>::new();
for i in 0..TREE_SIZE {
if i < REMOVE {
tree.insert(Key::rand(), Entry::rand_with_maker(TARGET_MAKER));
} else {
tree.insert(Key::rand(), Entry::rand());
}
}

// Start filter
let timer = Instant::now();
for _x in tree.drain_filter(
#[inline(always)]
|_k, v| v.maker == TARGET_MAKER,
) {}
if i > WARMUP_ITERS {
total_drain_micros += timer.elapsed().as_micros();
}
assert_eq!(tree.len(), 4096 - REMOVE);
}
println!("average drain: {total_drain_micros} micros");
}
62 changes: 62 additions & 0 deletions examples/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use sokoban::{NodeAllocatorMap, RedBlackTree};

fn main() {
{
let mut tree = RedBlackTree::<u32, u32, 8>::new();
tree.insert(0, 5); // this
tree.insert(1, 5);
tree.insert(2, 0); // this
tree.insert(3, 5);
tree.insert(4, 0); // this
tree.insert(5, 5);
tree.insert(6, 5);
tree.insert(7, 0); // this

println!("initial elements:");
for x in tree.iter() {
println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}

println!("\n Removing nodes");
for x in tree.drain_filter(my_predicate) {
println!("removed node {} {}", x.0, x.1);
}

println!("\n remaining elements:");
for x in tree.iter() {
println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}
}

// Identical, but uses filter which allocates a vector for the key value pairs
{
let mut tree = RedBlackTree::<u32, u32, 8>::new();
tree.insert(0, 5); // this
tree.insert(1, 5);
tree.insert(2, 0); // this
tree.insert(3, 5);
tree.insert(4, 0); // this
tree.insert(5, 5);
tree.insert(6, 5);
tree.insert(7, 0); // this

println!("initial elements:");
for x in tree.iter() {
println!("initial node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}

println!("\n Removing nodes");
for x in tree.filter(my_predicate) {
println!("removed node {} {}", x.0, x.1);
}

println!("\n remaining elements:");
for x in tree.iter() {
println!("remaining node({}) {} {}", tree.get_addr(&x.0), x.0, x.1);
}
}
}

fn my_predicate(key: &u32, value: &u32) -> bool {
(*key == 0) | (*value == 0)
}
113 changes: 113 additions & 0 deletions src/red_black_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,32 @@ impl<
terminated: false,
}
}

pub fn filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>(
&'a mut self,
predicate: P,
) -> Vec<(K, V)> {
self.drain_filter(predicate).collect::<Vec<_>>()
}

pub fn drain_filter<'a, P: for<'p> Fn(&'p K, &'p V) -> bool>(
&'a mut self,
predicate: P,
) -> RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE> {
let node = self.root;
RedBlackTreeDrainFilter {
tree: self,
fwd_stack: vec![],
fwd_ptr: node,
fwd_node: None,
_rev_stack: vec![],
_rev_ptr: node,
rev_node: None,
terminated: false,
remove: vec![],
jarry-xiao marked this conversation as resolved.
Show resolved Hide resolved
predicate,
}
}
}

impl<
Expand Down Expand Up @@ -933,6 +959,93 @@ impl<
}
}

pub struct RedBlackTreeDrainFilter<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> {
tree: &'a mut RedBlackTree<K, V, MAX_SIZE>,
fwd_stack: Vec<u32>,
fwd_ptr: u32,
fwd_node: Option<u32>,
_rev_stack: Vec<u32>,
_rev_ptr: u32,
rev_node: Option<u32>,
terminated: bool,
/// Keeps addr's of nodes with predicate = true, to remove upon dropping.
/// It is possible to instead drop during iteration and update fwd_ptr & fwd_stack
remove: Vec<u32>,
predicate: P,
}

impl<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> Drop for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE>
{
fn drop(&mut self) {
Copy link
Collaborator

@jarry-xiao jarry-xiao Dec 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gist of the speed up here is that the traversal traversal to remove nodes is O(K) instead of O(K log N) right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually still O(K log N) but you save a log N on the find

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially, with remove you have to traverse the tree from the root to the entry to re-find the node addr of the entry you just found.

If you iterate through the elements and record the addr instead of the keys, you don't have to do that traversal to remove the entry. That's the gist

for node_index in core::mem::take(&mut self.remove) {
self.tree._remove_tree_node(node_index);
}
}
}

impl<
'a,
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
P: for<'p> Fn(&'p K, &'p V) -> bool,
const MAX_SIZE: usize,
> Iterator for RedBlackTreeDrainFilter<'a, K, V, P, MAX_SIZE>
{
type Item = (K, V);

fn next(&mut self) -> Option<Self::Item> {
while !self.terminated && (!self.fwd_stack.is_empty() || self.fwd_ptr != SENTINEL) {
if self.fwd_ptr != SENTINEL {
self.fwd_stack.push(self.fwd_ptr);
self.fwd_ptr = self.tree.get_left(self.fwd_ptr);
} else {
let current_node = self.fwd_stack.pop();
if current_node == self.rev_node {
self.terminated = true;
return None;
}
self.fwd_node = current_node;
let ptr = self.fwd_node.unwrap();

// Get node, check predicate.
// If predicate, remove and return
let node = self
.tree
.allocator
.nodes
.get((ptr - 1) as usize)
.unwrap()
.get_value();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there not a helper function to perform this lookup?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part was mostly copied from the iterator implementation, but perhaps

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switched to get_node. much neater


if (self.predicate)(&node.key, &node.value) {
let (key, value) = (node.key, node.value);
self.fwd_ptr = self.tree.get_right(ptr);
self.remove.push(ptr);

// Remove and return
return Some((key, value));
} else {
self.fwd_ptr = self.tree.get_right(ptr);
}
}
}

None
}
}

impl<
K: Debug + PartialOrd + Ord + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
Expand Down