From 5a839395ee6fd1303b54403ff332103c6295bf2f Mon Sep 17 00:00:00 2001 From: Truong Nhan Nguyen <80200848+sozelfist@users.noreply.github.com> Date: Fri, 1 Nov 2024 00:47:25 +0700 Subject: [PATCH] Refactor Segment Tree Implementation (#835) ref: refactor segment tree --- src/data_structures/segment_tree.rs | 323 ++++++++++++++++------------ 1 file changed, 181 insertions(+), 142 deletions(-) diff --git a/src/data_structures/segment_tree.rs b/src/data_structures/segment_tree.rs index 1a55dc8a47e..f569381967e 100644 --- a/src/data_structures/segment_tree.rs +++ b/src/data_structures/segment_tree.rs @@ -1,185 +1,224 @@ -use std::cmp::min; +//! A module providing a Segment Tree data structure for efficient range queries +//! and updates. It supports operations like finding the minimum, maximum, +//! and sum of segments in an array. + use std::fmt::Debug; use std::ops::Range; -/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays. -/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...] -/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together. -/// Note that this function should be commutative and associative -/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b` -pub struct SegmentTree<T: Debug + Default + Ord + Copy> { - len: usize, // length of the represented - tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance) - merge: fn(T, T) -> T, // how we merge two values together +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. 
+#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. + InvalidRange, +} + +/// A structure representing a Segment Tree. This tree can be used to efficiently +/// perform range queries and updates on an array of elements. +pub struct SegmentTree<T, F> +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The length of the input array for which the segment tree is built. + size: usize, + /// A vector representing the segment tree. + nodes: Vec<T>, + /// A merging function defined as a closure or callable type. + merge_fn: F, } -impl<T: Debug + Default + Ord + Copy> SegmentTree<T> { - /// Builds a SegmentTree from an array and a merge function - pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self { - let len = arr.len(); - let mut buf: Vec<T> = vec![T::default(); 2 * len]; - // Populate the tree bottom-up, from right to left - buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value - for i in (1..len).rev() { - // a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2 - buf[i] = merge(buf[2 * i], buf[2 * i + 1]); +impl<T, F> SegmentTree<T, F> +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` to initialize the segment tree. + /// * `merge`: A merging function that defines how to merge two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance populated with the given elements. 
+ pub fn from_vec(arr: &[T], merge: F) -> Self { + let size = arr.len(); + let mut buffer: Vec<T> = vec![T::default(); 2 * size]; + + // Populate the leaves of the tree + buffer[size..(2 * size)].clone_from_slice(arr); + for idx in (1..size).rev() { + buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); } + SegmentTree { - len, - tree: buf, - merge, + size, + nodes: buffer, + merge_fn: merge, } } - /// Query the range (exclusive) - /// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.) - /// return the aggregate of values over this range otherwise - pub fn query(&self, range: Range<usize>) -> Option<T> { - let mut l = range.start + self.len; - let mut r = min(self.len, range.end) + self.len; - let mut res = None; - // Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations - while l < r { - if l % 2 == 1 { - res = Some(match res { - None => self.tree[l], - Some(old) => (self.merge)(old, self.tree[l]), + /// Queries the segment tree for the result of merging the elements in the given range. + /// + /// # Arguments + /// + /// * `range`: A range specified as `Range<usize>`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query was successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. 
+ pub fn query(&self, range: Range<usize>) -> Result<Option<T>, SegmentTreeError> { + if range.start >= self.size || range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + + let mut left = range.start + self.size; + let mut right = range.end + self.size; + let mut result = None; + + // Iterate through the segment tree to accumulate results + while left < right { + if left % 2 == 1 { + result = Some(match result { + None => self.nodes[left], + Some(old) => (self.merge_fn)(old, self.nodes[left]), }); - l += 1; + left += 1; } - if r % 2 == 1 { - r -= 1; - res = Some(match res { - None => self.tree[r], - Some(old) => (self.merge)(old, self.tree[r]), + if right % 2 == 1 { + right -= 1; + result = Some(match result { + None => self.nodes[right], + Some(old) => (self.merge_fn)(old, self.nodes[right]), }); } - l /= 2; - r /= 2; + left /= 2; + right /= 2; } - res + + Ok(result) } - /// Updates the value at index `idx` in the original array with a new value `val` - pub fn update(&mut self, idx: usize, val: T) { - // change every value where `idx` plays a role, bottom -> up - // 1: change in the right-hand side of the tree (bottom row) - let mut idx = idx + self.len; - self.tree[idx] = val; - - // 2: then bubble up - idx /= 2; - while idx != 0 { - self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]); - idx /= 2; + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. 
+ pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { + if idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + + let mut index = idx + self.size; + if self.nodes[index] == val { + return Ok(()); } + + self.nodes[index] = val; + while index > 1 { + index /= 2; + self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); + } + + Ok(()) } } #[cfg(test)] mod tests { use super::*; - use quickcheck::TestResult; - use quickcheck_macros::quickcheck; use std::cmp::{max, min}; #[test] fn test_min_segments() { let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; - let min_seg_tree = SegmentTree::from_vec(&vec, min); - assert_eq!(Some(-5), min_seg_tree.query(4..7)); - assert_eq!(Some(-30), min_seg_tree.query(0..vec.len())); - assert_eq!(Some(-30), min_seg_tree.query(0..2)); - assert_eq!(Some(-4), min_seg_tree.query(1..3)); - assert_eq!(Some(-5), min_seg_tree.query(1..7)); + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); } #[test] fn test_max_segments() { - let val_at_6 = 6; - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; let mut 
max_seg_tree = SegmentTree::from_vec(&vec, max); - assert_eq!(Some(15), max_seg_tree.query(0..vec.len())); - let max_4_to_6 = 6; - assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7)); - let delta = 2; - max_seg_tree.update(6, val_at_6 + delta); - assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7)); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); } #[test] fn test_sum_segments() { - let val_at_6 = 6; - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); - for (i, val) in vec.iter().enumerate() { - assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1))); - } - let sum_4_to_6 = sum_seg_tree.query(4..7); - assert_eq!(Some(4), sum_4_to_6); - let delta = 3; - sum_seg_tree.update(6, val_at_6 + delta); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + 
assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); assert_eq!( - sum_4_to_6.unwrap() + delta, - sum_seg_tree.query(4..7).unwrap() + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) ); - } - - // Some properties over segment trees: - // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. - // When asking for an interval containing a single value, return this value, no matter the merge function - - #[quickcheck] - fn check_overall_interval_min(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, min); - TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_overall_interval_max(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_overall_interval_sum(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_single_interval_min(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, min); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() - } - - #[quickcheck] - fn check_single_interval_max(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return 
TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() - } - - #[quickcheck] - fn check_single_interval_sum(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() } }