diff --git a/rust/structures/mergeable_heap.rs b/rust/structures/mergeable_heap.rs index 40ee8359..c205a40a 100644 --- a/rust/structures/mergeable_heap.rs +++ b/rust/structures/mergeable_heap.rs @@ -1,72 +1,88 @@ -use std::cell::RefCell; -use std::mem::{replace, swap, take}; -use std::rc::Rc; +use std::mem::swap; struct Heap { value: V, - left: Option>>>, - right: Option>>>, + left: Option>>, + right: Option>>, } -impl Heap { - fn new(value: V) -> Option>>> { - Some(Rc::new(RefCell::new(Self { +impl Heap { + fn new(value: V) -> Option>> { + Some(Box::new(Self { value, left: None, right: None, - }))) + })) } - fn merge<'a>( - mut a: &'a Option>>>, - mut b: &'a Option>>>, - ) -> Option>>> { + + fn merge(a: Option>>, b: Option>>) -> Option>> { if a.is_none() { - return b.clone(); + return b; } if b.is_none() { - return a.clone(); + return a; } - if a.as_ref()?.borrow().value > b.as_ref()?.borrow().value { - swap(&mut a, &mut b); + let mut ra = a.unwrap(); + let mut rb = b.unwrap(); + if ra.value > rb.value { + swap(&mut ra, &mut rb); } if rand::random() { - let mut ra = a.as_ref()?.borrow_mut(); - let l = take(&mut ra.left); - let r = take(&mut ra.right); - ra.left = r; - ra.right = l; + swap(&mut ra.left, &mut ra.right); } - let m = Self::merge(replace(&mut &a.as_ref()?.borrow_mut().left, &None), b); - a.as_ref()?.borrow_mut().left = m; - a.clone() + ra.left = Self::merge(ra.left, Some(rb)); + Some(ra) } - fn remove_min(heap: &Option>>>) -> (Option>>>, V) { - let h = heap.as_ref().unwrap().borrow(); - (Self::merge(&h.left, &h.right), h.value.clone()) + fn remove_min(heap: Option>>) -> (Option>>, V) { + let h = heap.unwrap(); + (Self::merge(h.left, h.right), h.value) } - fn add(heap: &Option>>>, value: V) -> Option>>> { - Self::merge(heap, &Heap::new(value)) + fn add(heap: Option>>, value: V) -> Option>> { + Self::merge(heap, Heap::new(value)) } } #[cfg(test)] mod tests { use crate::structures::mergeable_heap::Heap; + use rand::seq::SliceRandom; + use rand::thread_rng; + use rstest::rstest; + + #[rstest] + #[case(&mut [])] + #[case(&mut [0])] + #[case(&mut [1, 1])] + #[case(&mut [3, 1, 2])] + fn basic_test(#[case] seq: &mut [u32]) { + test(seq); + } + + #[test] + fn big_test1() { + let mut values = (0..10_000).collect::>(); + values.shuffle(&mut thread_rng()); + test(&mut values); + } #[test] - fn basic_test() { - let mut h = None; - h = Heap::add(&h, 3); - h = Heap::add(&h, 1); - h = Heap::add(&h, 2); - let mut values = Vec::new(); - while h.is_some() { - let (heap, min_value) = Heap::remove_min(&h); - values.push(min_value); - h = heap; + fn big_test2() { + let mut values = vec![0; 10_000]; + values.shuffle(&mut thread_rng()); + test(&mut values); + } + + fn test(seq: &mut [u32]) { + let mut heap = seq.iter().fold(None, |h, v| Heap::add(h, v)); + let mut heap_sorted_values = Vec::new(); + while heap.is_some() { + let (updated_heap, min_value) = Heap::remove_min(heap); + heap = updated_heap; + heap_sorted_values.push(*min_value); } - assert_eq!(values, [1, 2, 3]); + seq.sort(); + assert_eq!(heap_sorted_values, seq); } }