Skip to content

Commit

Permalink
sorts: Make heap sort generic
Browse files Browse the repository at this point in the history
  • Loading branch information
XuShaohua committed May 23, 2024
1 parent aa680f7 commit f3a68cb
Showing 1 changed file with 71 additions and 124 deletions.
195 changes: 71 additions & 124 deletions sort/src/heap_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,145 +2,92 @@
// Use of this source is governed by General Public License that can be found
// in the LICENSE file.

#![allow(dead_code)]

pub fn heap_sort(list: &mut [i32]) {
let mut max_heap = MaxHeap::from_slice(list);
pub fn heap_sort<T: PartialOrd>(list: &mut [T]) {
let max_heap = list;

// 建立初始堆.
// 从最后一个非叶节点开始, 进行下移操作.
let len = max_heap.len();
let parent = (len - 1) / 2;
for i in (0..=parent).rev() {
max_heap.shift_down(i, len);
shift_down(max_heap, i, len);
}

// 每个循环, 将最大元素下沉到数组尾部, 每次循环, 数组长度减少1,
// 直到最后只剩下 1 个元素.
for i in (0..len).rev() {
// 交换根节点与当前堆的最后一个节点
max_heap.heap.swap(0, i);
// 交换根节点与当前堆的最后一个节点, 把最大的元素放到数组尾部.
max_heap.swap(0, i);

// 从根节点开始, 进行下移操作
max_heap.shift_down(0, i);
// 重新调整堆顶, 从根节点开始, 进行下移操作.
// 这里, 忽略了刚刚的最大的那个元素.
shift_down(max_heap, 0, i);
}

// 将排序后的结果更新到 list.
list.swap_with_slice(&mut max_heap.heap);
}

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct MaxHeap {
heap: Vec<i32>,
}

impl Default for MaxHeap {
fn default() -> Self {
Self::new()
}
}

impl MaxHeap {
#[must_use]
#[inline]
pub const fn new() -> Self {
Self { heap: Vec::new() }
}

pub fn from_slice(slice: &[i32]) -> Self {
Self {
heap: slice.to_vec(),
// 将该节点与子节点进行比较, 如果比子节点大, 就让它跟较大的那个子节点交换,
// 一直重复这个过 程, 直到找到合适的位置.
//
// 时间复杂度是 O(log(n)), 是树的高度.
fn shift_down<T: PartialOrd>(heap: &mut [T], mut pos: usize, len: usize) {
while pos * 2 + 1 < len {
let left = pos * 2 + 1;
let right = pos * 2 + 2;

let larger = if right >= len {
// 右子树不存在, 它超出了数组的范围.
left
} else if heap[left] >= heap[right] {
// 左侧子树较大
left
} else {
// 右侧子树较大
right
};

// 当前节点与较大的子节点进行比较, 如果比它小就进行交换.
if heap[pos] < heap[larger] {
heap.swap(pos, larger);
pos = larger;
} else {
// 当前节点比最大的子节点还大
break;
}
}
}

#[must_use]
#[inline]
pub fn len(&self) -> usize {
self.heap.len()
}

#[must_use]
#[inline]
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
}

#[must_use]
#[inline]
pub fn peak(&self) -> Option<i32> {
self.heap.first().copied()
}

pub fn push(&mut self, val: i32) {
// 先将新的元素插入到数组尾部.
self.heap.push(val);

// 然后将该元素进行上移 (shift up).
self.shift_up(self.heap.len() - 1);
}

#[must_use]
pub fn pop(&mut self) -> Option<i32> {
if self.is_empty() {
return None;
}

// 将堆顶的元素交换到数组的尾部, 并弹出它.
let len = self.len();
self.heap.swap(0, len - 1);
let last = self.heap.pop();

// 然后调整堆顶的位置, 让它保持大顶堆的特性.
self.shift_down(0, len);

last
}

// 将该节点与父节点进行比较, 如果比父节点大, 就让它跟父节点交换,
// 一直重复这个过程, 直到找到合适的位置.
//
// 时间复杂度是 O(log(n)), 是树的高度.
fn shift_up(&mut self, mut pos: usize) {
if pos == 0 {
return;
}

let mut parent_pos = (pos - 1) / 2;
while self.heap[parent_pos] < self.heap[pos] {
if parent_pos == 0 {
break;
}
pos = parent_pos;
parent_pos = (pos - 1) / 2;
}
}

// 将该节点与子节点进行比较, 如果比子节点大, 就让它跟较大的那个子节点交换,
// 一直重复这个过 程, 直到找到合适的位置.
//
// 时间复杂度是 O(log(n)), 是树的高度.
fn shift_down(&mut self, mut pos: usize, len: usize) {
while pos * 2 + 1 < len {
let left = pos * 2 + 1;
let right = pos * 2 + 2;

let larger = if right >= len {
// 右子树不存在, 它超出了数组的范围.
left
} else if self.heap[left] >= self.heap[right] {
// 左侧子树较大
left
} else {
// 右侧子树较大
right
};

// 当前节点与较大的子节点进行比较, 如果比它小就进行交换.
if self.heap[pos] < self.heap[larger] {
self.heap.swap(pos, larger);
pos = larger;
} else {
// 当前节点比最大的子节点还大
break;
}
}
#[cfg(test)]
mod tests {
use super::heap_sort;

#[test]
fn test_heap_sort() {
let mut list = [0, 5, 3, 2, 2];
heap_sort(&mut list);
assert_eq!(list, [0, 2, 2, 3, 5]);

let mut list = [-2, -5, -45];
heap_sort(&mut list);
assert_eq!(list, [-45, -5, -2]);

let mut list = [
-998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -983_833,
-987_905, -980_069, -977_640,
];
heap_sort(&mut list);
assert_eq!(
list,
[
-998_166, -996_360, -995_703, -995_238, -995_066, -994_740, -992_987, -987_905,
-983_833, -980_069, -977_640,
]
);

let mut list = "EASYQUESTION".chars().collect::<Vec<_>>();
heap_sort(&mut list);
assert_eq!(
list,
['A', 'E', 'E', 'I', 'N', 'O', 'Q', 'S', 'S', 'T', 'U', 'Y']
);
}
}

0 comments on commit f3a68cb

Please sign in to comment.