Skip to content

Commit

Permalink
perf: optimized StrainsVec
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxOhn committed Mar 3, 2025
1 parent 9655951 commit cc10389
Showing 1 changed file with 130 additions and 43 deletions.
173 changes: 130 additions & 43 deletions src/util/strains_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ pub use inner::*;

#[cfg(not(feature = "raw_strains"))]
mod inner {
use std::{iter::Copied, slice::Iter};
use std::{
iter::{self, Copied},
slice::{self, Iter},
};

use self::entry::StrainsEntry;

Expand All @@ -26,6 +29,7 @@ mod inner {
impl StrainsVec {
/// Constructs a new, empty [`StrainsVec`] with at least the specified
/// capacity.
#[inline]
pub fn with_capacity(capacity: usize) -> Self {
Self {
inner: Vec::with_capacity(capacity),
Expand All @@ -36,14 +40,17 @@ mod inner {
}

/// Returns the number of elements.
#[inline]
pub const fn len(&self) -> usize {
self.len
}

/// Appends an element to the back.
#[inline]
pub fn push(&mut self, value: f64) {
if value.to_bits() > 0 {
self.inner.push(StrainsEntry::new_value(value));
if likely(value.to_bits() > 0 && value.is_sign_positive()) {
// SAFETY: we just checked whether it's positive
self.inner.push(unsafe { StrainsEntry::new_value(value) });
} else if let Some(last) = self.inner.last_mut().filter(|e| e.is_zero()) {
last.incr_zero_count();
} else {
Expand All @@ -59,6 +66,7 @@ mod inner {
}

/// Sorts the entries in descending order.
#[inline]
pub fn sort_desc(&mut self) {
#[cfg(debug_assertions)]
debug_assert!(!self.has_zero);
Expand All @@ -67,8 +75,9 @@ mod inner {
}

/// Removes all zero entries
#[inline]
pub fn retain_non_zero(&mut self) {
self.inner.retain(StrainsEntry::is_value);
self.inner.retain(|e| likely(e.is_value()));

#[cfg(debug_assertions)]
{
Expand All @@ -77,6 +86,7 @@ mod inner {
}

/// Removes all zeros and sorts the remaining entries in descending order.
#[inline]
pub fn retain_non_zero_and_sort(&mut self) {
self.retain_non_zero();
self.sort_desc();
Expand All @@ -85,6 +95,7 @@ mod inner {
/// Iterator over the raw entries, assuming that there are no zeros.
///
/// Panics if there are zeros.
#[inline]
pub fn non_zero_iter(&self) -> impl ExactSizeIterator<Item = f64> + '_ {
#[cfg(debug_assertions)]
debug_assert!(!self.has_zero);
Expand All @@ -95,6 +106,7 @@ mod inner {
/// Same as [`StrainsVec::retain_non_zero_and_sort`] followed by
/// [`StrainsVec::iter`] but the resulting iterator is faster
/// because it doesn't need to check whether entries are zero.
#[inline]
pub fn sorted_non_zero_iter(&mut self) -> impl ExactSizeIterator<Item = f64> + '_ {
self.retain_non_zero_and_sort();

Expand All @@ -103,30 +115,75 @@ mod inner {

/// Removes all zeros, sorts the remaining entries in descending order, and
/// returns an iterator over mutable references to the values.
#[inline]
pub fn sorted_non_zero_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut f64> {
self.retain_non_zero_and_sort();

self.inner.iter_mut().map(StrainsEntry::as_value_mut)
}

/// Sum up all values.
#[inline]
pub fn sum(&self) -> f64 {
self.inner
.iter()
.copied()
.filter(StrainsEntry::is_value)
.fold(0.0, |sum, e| sum + e.value())
.filter_map(StrainsEntry::try_as_value)
.sum()
}

/// Returns an iterator over the [`StrainsVec`].
#[inline]
pub fn iter(&self) -> StrainsIter<'_> {
StrainsIter::new(self)
}

/// Allocates a new `Vec<f64>` to store all values, including zeros.
pub fn into_vec(self) -> Vec<f64> {
/// Copies the first `count` items of `slice` into `dst`.
fn copy_slice(slice: &[StrainsEntry], count: usize, dst: &mut Vec<f64>) {
if unlikely(count == 0) {
return;
}

let ptr = slice.as_ptr().cast();

// SAFETY: `StrainsEntry` has the same properties as `f64`
let slice = unsafe { slice::from_raw_parts(ptr, count) };
dst.extend_from_slice(slice);
}

/// Drives the iterator until it finds a zero count. It then copies
/// entries up to that and returns the zero count.
#[inline]
fn copy_non_zero(
iter: &mut Iter<'_, StrainsEntry>,
dst: &mut Vec<f64>,
) -> Option<usize> {
let mut count = 0;
let slice = iter.as_slice();

for entry in iter {
if unlikely(entry.is_zero()) {
copy_slice(slice, count, dst);

return Some(entry.zero_count() as usize);
}

count += 1;
}

copy_slice(slice, count, dst);

None
}

let mut vec = Vec::with_capacity(self.len);
vec.extend(self.iter());
let mut iter = self.inner.iter();

while let Some(zero_count) = copy_non_zero(&mut iter, &mut vec) {
vec.extend(iter::repeat(0.0).take(zero_count));
}

vec
}
Expand Down Expand Up @@ -158,7 +215,7 @@ mod inner {
loop {
let curr = self.curr.as_mut()?;

if curr.is_value() {
if likely(curr.is_value()) {
let value = curr.value();
self.curr = self.inner.next();
self.len -= 1;
Expand Down Expand Up @@ -190,6 +247,8 @@ mod inner {

/// Private module to hide internal fields.
mod entry {
use super::likely;

/// Either a positive `f64` or an amount of consecutive `0.0`.
///
/// If the first bit is not set, i.e. the sign bit of a `f64` indicates
Expand All @@ -204,68 +263,95 @@ mod inner {
impl StrainsEntry {
const ZERO_COUNT_MASK: u64 = u64::MAX >> 1;

pub fn new_value(value: f64) -> Self {
debug_assert!(
value.is_sign_positive(),
"attempted to create negative strain entry, please report as a bug"
);

/// # Safety
///
/// `value` must be positive, i.e. neither negative nor zero.
#[inline]
pub const unsafe fn new_value(value: f64) -> Self {
Self { value }
}

#[inline]
pub const fn new_zero() -> Self {
Self {
zero_count: !Self::ZERO_COUNT_MASK + 1,
}
}

#[inline]
pub const fn is_zero(self) -> bool {
unsafe { self.value.is_sign_negative() }
}

// Requiring `self` as a reference improves ergonomics for passing this
// method as argument to higher-order functions.
#[allow(clippy::trivially_copy_pass_by_ref)]
pub const fn is_value(&self) -> bool {
#[inline]
pub const fn is_value(self) -> bool {
!self.is_zero()
}

pub fn value(self) -> f64 {
debug_assert!(self.is_value());

#[inline]
pub const fn value(self) -> f64 {
unsafe { self.value }
}

pub fn as_value_mut(&mut self) -> &mut f64 {
debug_assert!(self.is_value());
#[inline]
pub const fn try_as_value(self) -> Option<f64> {
if likely(self.is_value()) {
Some(self.value())
} else {
None
}
}

#[inline]
pub fn as_value_mut(&mut self) -> &mut f64 {
unsafe { &mut self.value }
}

pub fn zero_count(self) -> u64 {
debug_assert!(self.is_zero());

#[inline]
pub const fn zero_count(self) -> u64 {
unsafe { self.zero_count & Self::ZERO_COUNT_MASK }
}

#[inline]
pub fn incr_zero_count(&mut self) {
debug_assert!(self.is_zero());

unsafe {
self.zero_count += 1;
}
}

#[inline]
pub fn decr_zero_count(&mut self) {
debug_assert!(self.is_zero());

unsafe {
self.zero_count -= 1;
}
}
}
}

#[inline]
#[cold]
const fn cold() {}

/// Hints at the compiler that the condition is likely `true`.
#[inline]
const fn likely(b: bool) -> bool {
if !b {
cold();
}

b
}

/// Hints at the compiler that the condition is likely `false`.
#[inline]
const fn unlikely(b: bool) -> bool {
if b {
cold();
}

b
}

#[cfg(test)]
mod tests {
use proptest::prelude::*;
Expand All @@ -276,20 +362,25 @@ mod inner {

proptest! {
#[test]
fn expected(mut values in prop::collection::vec(prop::option::of(0.0..1_000.0), 0..1_000)) {
fn expected(values in prop::collection::vec(prop::option::of(0.0..1_000.0), 0..1_000)) {
let mut vec = StrainsVec::with_capacity(values.len());
let mut raw = Vec::with_capacity(values.len());

let mut additional_zeros = 0;
let mut prev_zero = false;
let mut sum = 0.0;

for opt in values.iter().copied() {
if let Some(value) = opt {
if let Some(value) = opt.filter(|&value| value != 0.0) {
let value = f64::abs(value);

vec.push(value);
raw.push(value);
prev_zero = false;
sum += value;
} else {
vec.push(0.0);
raw.push(0.0);

if prev_zero {
additional_zeros += 1;
Expand All @@ -299,20 +390,16 @@ mod inner {
}
}

assert_eq!(vec.len(), values.len());
assert_eq!(vec.inner.len(), values.len() - additional_zeros);
assert_eq!(vec.len(), raw.len());
assert_eq!(vec.inner.len(), raw.len() - additional_zeros);
assert!(vec.sum().eq(sum));
assert!(vec.iter().eq(values.iter().copied().map(|opt| opt.unwrap_or(0.0))));

values.retain(Option::is_some);

values.sort_by(|a, b| {
let (Some(a), Some(b)) = (a, b) else { unreachable!() };
assert!(vec.iter().eq(raw.iter().copied()));
assert_eq!(vec.clone().into_vec(), raw);

b.total_cmp(a)
});
raw.retain(|&n| n > 0.0);
raw.sort_by(|a, b| b.total_cmp(a));

assert!(vec.sorted_non_zero_iter().eq(values.into_iter().flatten()));
assert!(vec.sorted_non_zero_iter().eq(raw));
}
}
}
Expand Down

0 comments on commit cc10389

Please sign in to comment.