diff --git a/src/tree/node.rs b/src/tree/node.rs index 8cc0582..ea3b65c 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -2,29 +2,29 @@ use super::sphere::Sphere; pub struct InsertionEntry { pub idx: usize, - pub sphere: Sphere, pub parent_height: usize, + pub sphere: Sphere, } pub struct Node { pub idx: usize, + pub parent: usize, pub height: usize, pub sphere: Sphere, - pub parent: usize, - pub children: Vec, pub variance: [f64; D], + pub children: Vec, pub bound: f64, } impl Node { - pub fn new(idx: usize, height: usize, sphere: Sphere, parent: usize) -> Node { + pub fn new(idx: usize, parent: usize, height: usize, sphere: Sphere) -> Node { Node { idx, + parent, height, sphere, - parent, - children: Vec::new(), variance: [f64::INFINITY; D], + children: Vec::new(), bound: f64::INFINITY, } } diff --git a/src/tree/sstree.rs b/src/tree/sstree.rs index 0006cd2..077b0af 100644 --- a/src/tree/sstree.rs +++ b/src/tree/sstree.rs @@ -15,11 +15,10 @@ pub struct SSTree { root: usize, node_max_entries: usize, nodes: Vec>, - coreset: Vec, } impl Index for SSTree { - fn insert(&mut self, point: [f64; D]) -> Vec { + fn insert(&mut self, point: [f64; D]) { // Insert the new point let new_point_index = self.data.len(); self.data.push(point); @@ -32,49 +31,42 @@ impl Index for SSTree { self.add_core(new_point_index, neighbor_idx); } - // Init the coreset points - self.coreset.clear(); - self.coreset.push(new_point_index); - - if self.root == usize::MAX { - self.root = self.nodes.len(); - let mut root = Node::new(self.root, 0, Sphere::new(point, 0.0), usize::MAX); - root.children.push(new_point_index); - self.nodes.push(root); - self.reshape(self.root); - } else { - let mut reinsert_entries = vec![InsertionEntry { - idx: new_point_index, - sphere: Sphere::new(point, 0.0), - parent_height: 0, - }]; - let mut reinsert_height = 0; - while let Some(entry) = reinsert_entries.pop() { - let new_reinsert_entries = self.insert_recursive(entry, self.root, reinsert_height); - reinsert_entries.extend(new_reinsert_entries); - reinsert_height += 1; - } - - if self.nodes[self.root].children.len() > self.node_max_entries { - let old_root_idx = self.root; - let sibling_entry = self.split(self.root); - let new_root_idx = self.nodes.len(); - let mut new_root = Node::new( - new_root_idx, - self.nodes[old_root_idx].height + 1, - Sphere::new(self.nodes[sibling_entry.idx].sphere.center, 0.0), - usize::MAX, - ); - new_root.children = vec![old_root_idx, sibling_entry.idx]; - self.nodes[old_root_idx].parent = new_root_idx; - self.nodes[sibling_entry.idx].parent = new_root_idx; - self.nodes.push(new_root); - self.reshape(new_root_idx); - self.root = new_root_idx; - } + let mut reinsert_entries = vec![InsertionEntry { + idx: new_point_index, + parent_height: 0, + sphere: Sphere::new(point, 0.0), + }]; + let mut reinsert_height = 0; + while let Some(entry) = reinsert_entries.pop() { + let new_reinsert_entries = self.insert_recursive(entry, self.root, reinsert_height); + reinsert_entries.extend(new_reinsert_entries); + reinsert_height += 1; + } + + if self.nodes[self.root].children.len() > self.node_max_entries { + let old_root_idx = self.root; + let sibling_entry = self.split(self.root); + let new_root_idx = self.nodes.len(); + let mut new_root = Node::new( + new_root_idx, + usize::MAX, + self.nodes[old_root_idx].height + 1, + Sphere::new(self.nodes[sibling_entry.idx].sphere.center, 0.0), + ); + new_root.children = vec![old_root_idx, sibling_entry.idx]; + self.nodes[old_root_idx].parent = new_root_idx; + self.nodes[sibling_entry.idx].parent = new_root_idx; + self.nodes.push(new_root); + self.reshape(new_root_idx); + self.root = new_root_idx; } self.update_core(self.root, new_point_index); - self.coreset.clone() + } + + fn rknn(&self, point: [f64; D]) -> Vec { + let mut result = Vec::new(); + self.rknn_recursive(self.root, &point, &mut result); + result } fn query_range(&self, point_index: usize, range: f64) -> Vec { @@ -105,14 +97,34 @@ impl Index for SSTree { impl SSTree { #[must_use] pub fn new(k: usize) -> Self { + let root = Node::new(0, usize::MAX, 0, Sphere::new([f64::INFINITY; D], 0.0)); Self { k, data: Vec::new(), neighbors: Vec::new(), - root: usize::MAX, + root: 0, node_max_entries: 2 * k + 1, - nodes: Vec::new(), - coreset: Vec::new(), + nodes: vec![root], + } + } + + fn rknn_recursive(&self, node_idx: usize, point: &[f64; D], rneighbors: &mut Vec) { + let distance_to_node = self.nodes[node_idx].sphere.min_distance(point); + if distance_to_node > self.nodes[node_idx].bound { + return; + } + + if self.nodes[node_idx].is_leaf() { + for neighbor_idx in &self.nodes[node_idx].children { + let distance = euclidean(point, &self.data[*neighbor_idx]); + if distance < self.core_distance_of(*neighbor_idx) { + rneighbors.push(*neighbor_idx); + } + } + } else { + for child_idx in &self.nodes[node_idx].children { + self.rknn_recursive(*child_idx, point, rneighbors); + } } } @@ -230,7 +242,6 @@ impl SSTree { if distance.0 >= cur_core_distance { return false; } - self.coreset.push(point_index); self.neighbors[point_index].push((distance, neighbor_index)); if self.neighbors[point_index].len() > self.k { self.neighbors[point_index].pop(); @@ -392,9 +403,9 @@ impl SSTree { let sibling_sphere = Sphere::new(left_centroid, 0.); let mut sibling = Node::new( sibling_idx, + parent, self.nodes[node_idx].height, sibling_sphere, - parent, ); sibling.children = sibling_children; self.nodes.push(sibling);