From 8421974b56ed8b9739828a53336987a9479c0204 Mon Sep 17 00:00:00 2001 From: Alex Huszagh Date: Thu, 5 Dec 2024 11:21:01 -0600 Subject: [PATCH] Add optimizations for digit counting. This adds optimizations for base 2 and base 4 radices, which can use fast log2 implementations for digit counting. --- lexical-write-integer/src/digit_count.rs | 114 +++++++++-- .../tests/digit_count_tests.rs | 181 +++++++++++++++++- 2 files changed, 282 insertions(+), 13 deletions(-) diff --git a/lexical-write-integer/src/digit_count.rs b/lexical-write-integer/src/digit_count.rs index 2a3f36f7..ec49b5dd 100644 --- a/lexical-write-integer/src/digit_count.rs +++ b/lexical-write-integer/src/digit_count.rs @@ -8,6 +8,7 @@ #![doc(hidden)] use lexical_util::{ + assert::debug_assert_radix, div128::u128_divrem, num::{AsPrimitive, UnsignedInteger}, step::u64_step, @@ -31,9 +32,44 @@ pub fn fast_log2(x: T) -> usize { T::BITS - 1 - (x | T::ONE).leading_zeros() as usize } -// Uses a naive approach to calculate the number of digits. -macro_rules! naive_count { - ($t:ty, $radix:expr, $x:expr) => {{ +// Algorithms to calculate the number of digits from a single value. +macro_rules! digit_count { + // Highly-optimized digit count for 2^N values. + (@2 $x:expr) => {{ + digit_log2($x) + }}; + + (@4 $x:expr) => {{ + digit_log4($x) + }}; + + (@8 $x:expr) => {{ + digit_log8($x) + }}; + + (@16 $x:expr) => {{ + digit_log16($x) + }}; + + (@32 $x:expr) => {{ + digit_log32($x) + }}; + + // Uses a naive approach to calculate the number of digits. + // This uses multi-digit optimizations when possible, and always + // accurately calculates the number of digits similar to how + // the digit generation algorithm works, just without the table + // lookups. + // + // There's no good way to do this, since float logn functions are + // lossy and the value might not be exactly represented in the type + // (that is, `>= 2^53`), in which case `log(b^x, b)`, `log(b^x + 1, b)`, + // and `log(b^x - 1, b)` would all be the same. Rust's integral [`ilog`] + // functions just use naive 1-digit at a time multiplication, so it's + // less efficient than our optimized variant. + // + // [`ilog`]: https://github.com/rust-lang/rust/blob/0e98766/library/core/src/num/uint_macros.rs#L1290-L1320 + (@naive $t:ty, $radix:expr, $x:expr) => {{ // If we can do multi-digit optimizations, it's ideal, // so we want to check if our type size max value is >= // to the value. @@ -74,6 +110,48 @@ macro_rules! naive_count { }}; } +/// Highly-optimized digit count for base2 values. +/// +/// This is always the number of `BITS - ctlz(x | 1)`, so it's +/// `fast_log2(x) + 1`. This is because 0 has 1 digit, as does 1, +/// but 2 and 3 have 2, etc. +#[inline(always)] +fn digit_log2(x: T) -> usize { + fast_log2(x) + 1 +} + +/// Highly-optimized digit count for base4 values. +/// +/// This is very similar to base 2, except we shift right by +/// 1 and adjust by 1. For example, `fast_log2(3) == 1`, so +/// `fast_log2(3) >> 1 == 0`, which then gives us our result. +#[inline(always)] +fn digit_log4(x: T) -> usize { + // NOTE: This cannot be `fast_log2(fast_log2())` + (fast_log2(x | T::ONE) >> 1) + 1 +} + +/// Specialized digit count for base8 values. +#[inline(always)] +fn digit_log8(x: T) -> usize { + // FIXME: Optimize + digit_count!(@naive T, 8, x) +} + +/// Specialized digit count for base16 values. +#[inline(always)] +fn digit_log16(x: T) -> usize { + // FIXME: Optimize + digit_count!(@naive T, 16, x) +} + +/// Specialized digit count for base32 values. +#[inline(always)] +fn digit_log32(x: T) -> usize { + // FIXME: Optimize + digit_count!(@naive T, 32, x) +} + /// Quickly calculate the number of digits in a type. /// /// This uses optimizations for powers-of-two and decimal @@ -93,11 +171,16 @@ pub unsafe trait DigitCount: UnsignedInteger + DecimalCount { fn digit_count(self, radix: u32) -> usize { assert!((2..=36).contains(&radix), "radix must be >= 2 and <= 36"); match radix { + // decimal 10 => self.decimal_count(), - // NOTE: This is currently horribly inefficient and exists just for correctness. - // FIXME: Optimize for power-of-two radices - // FIXME: Optimize for non-power-of-two radices. - _ => naive_count!(Self, radix, self), + // 2^N + 2 => digit_count!(@2 self), + 4 => digit_count!(@4 self), + 8 => digit_count!(@8 self), + 16 => digit_count!(@16 self), + 32 => digit_count!(@32 self), + // fallback + _ => digit_count!(@naive Self, radix, self), } } } @@ -119,15 +202,22 @@ unsafe impl DigitCount for u128 { /// Get the number of digits in a value. #[inline(always)] fn digit_count(self, radix: u32) -> usize { - assert!((2..=36).contains(&radix), "radix must be >= 2 and <= 36"); + debug_assert_radix(radix); match radix { + // decimal 10 => self.decimal_count(), - // FIXME: Optimize this + // 2^N + 2 => digit_count!(@2 self), + 4 => digit_count!(@4 self), + 8 => digit_count!(@8 self), + 16 => digit_count!(@16 self), + 32 => digit_count!(@32 self), + // fallback _ => { // NOTE: This follows the same implementation as the digit count // generation, so this is safe. if self <= u64::MAX as u128 { - return naive_count!(u64, radix, self as u64); + return digit_count!(@naive u64, radix, self as u64); } // Doesn't fit in 64 bits, let's try our divmod. @@ -135,13 +225,13 @@ unsafe impl DigitCount for u128 { let (value, _) = u128_divrem(self, radix); let mut count = step; if value <= u64::MAX as u128 { - count += naive_count!(u64, radix, value as u64); + count += digit_count!(@naive u64, radix, value as u64); } else { // Value has to be greater than 1.8e38 let (value, _) = u128_divrem(value, radix); count += step; if value != 0 { - count += naive_count!(u64, radix, value as u64); + count += digit_count!(@naive u64, radix, value as u64); } } diff --git a/lexical-write-integer/tests/digit_count_tests.rs b/lexical-write-integer/tests/digit_count_tests.rs index 05ab31f0..ed1b2cfd 100644 --- a/lexical-write-integer/tests/digit_count_tests.rs +++ b/lexical-write-integer/tests/digit_count_tests.rs @@ -2,7 +2,11 @@ mod util; -use lexical_write_integer::digit_count; +use lexical_write_integer::decimal::DecimalCount; +use lexical_write_integer::digit_count::{self, DigitCount}; +use proptest::prelude::*; + +use crate::util::default_proptest_config; #[test] fn fast_log2_test() { @@ -28,8 +32,183 @@ fn slow_log2(x: u32) -> usize { } } +#[test] +fn base10_count_test() { + assert_eq!(1, 0u32.digit_count(10)); + assert_eq!(1, 9u32.digit_count(10)); + assert_eq!(2, 10u32.digit_count(10)); + assert_eq!(2, 11u32.digit_count(10)); + assert_eq!(2, 99u32.digit_count(10)); + assert_eq!(3, 100u32.digit_count(10)); + assert_eq!(3, 101u32.digit_count(10)); +} + +#[test] +fn base2_count_test() { + assert_eq!(1, 0u32.digit_count(2)); + assert_eq!(1, 1u32.digit_count(2)); + assert_eq!(2, 2u32.digit_count(2)); + assert_eq!(2, 3u32.digit_count(2)); + assert_eq!(3, 4u32.digit_count(2)); + + if cfg!(feature = "power-of-two") { + for i in 1usize..=127 { + let value = 2u128.pow(i as u32); + assert_eq!(i + 1, value.digit_count(2)); + assert_eq!(i + 1, (value + 1).digit_count(2)); + assert_eq!(i, (value - 1).digit_count(2)); + } + } +} + +#[test] +fn base4_count_test() { + assert_eq!(1, 0u32.digit_count(4)); + assert_eq!(1, 1u32.digit_count(4)); + assert_eq!(1, 3u32.digit_count(4)); + assert_eq!(2, 4u32.digit_count(4)); + assert_eq!(2, 5u32.digit_count(4)); + assert_eq!(2, 15u32.digit_count(4)); + assert_eq!(3, 16u32.digit_count(4)); + assert_eq!(3, 17u32.digit_count(4)); + + if cfg!(feature = "power-of-two") { + for i in 1usize..=63 { + let value = 4u128.pow(i as u32); + assert_eq!(i + 1, value.digit_count(4)); + assert_eq!(i + 1, (value + 1).digit_count(4)); + assert_eq!(i, (value - 1).digit_count(4)); + + let halfway = value + 2u128.pow(i as u32); + assert_eq!(i + 1, halfway.digit_count(4)); + assert_eq!(i + 1, (halfway + 1).digit_count(4)); + assert_eq!(i + 1, (halfway - 1).digit_count(4)); + } + } +} + +#[test] +fn base8_count_test() { + assert_eq!(1, 0u32.digit_count(8)); + assert_eq!(1, 1u32.digit_count(8)); + assert_eq!(1, 7u32.digit_count(8)); + assert_eq!(2, 8u32.digit_count(8)); + assert_eq!(2, 9u32.digit_count(8)); + assert_eq!(2, 63u32.digit_count(8)); + assert_eq!(3, 64u32.digit_count(8)); + assert_eq!(3, 65u32.digit_count(8)); + + if cfg!(feature = "power-of-two") { + for i in 1usize..=31 { + let value = 8u128.pow(i as u32); + assert_eq!(i + 1, value.digit_count(8)); + assert_eq!(i + 1, (value + 1).digit_count(8)); + assert_eq!(i, (value - 1).digit_count(8)); + + let halfway = value + 4u128.pow(i as u32); + assert_eq!(i + 1, halfway.digit_count(8)); + assert_eq!(i + 1, (halfway + 1).digit_count(8)); + assert_eq!(i + 1, (halfway - 1).digit_count(8)); + } + } +} + +#[test] +fn base16_count_test() { + assert_eq!(1, 0u32.digit_count(16)); + assert_eq!(1, 1u32.digit_count(16)); + assert_eq!(1, 15u32.digit_count(16)); + assert_eq!(2, 16u32.digit_count(16)); + assert_eq!(2, 17u32.digit_count(16)); + assert_eq!(2, 255u32.digit_count(16)); + assert_eq!(3, 256u32.digit_count(16)); + assert_eq!(3, 257u32.digit_count(16)); + + if cfg!(feature = "power-of-two") { + for i in 1usize..=15 { + let value = 16u128.pow(i as u32); + assert_eq!(i + 1, value.digit_count(16)); + assert_eq!(i + 1, (value + 1).digit_count(16)); + assert_eq!(i, (value - 1).digit_count(16)); + + let halfway = value + 8u128.pow(i as u32); + assert_eq!(i + 1, halfway.digit_count(16)); + assert_eq!(i + 1, (halfway + 1).digit_count(16)); + assert_eq!(i + 1, (halfway - 1).digit_count(16)); + } + } +} + +#[test] +fn base32_count_test() { + assert_eq!(1, 0u32.digit_count(32)); + assert_eq!(1, 1u32.digit_count(32)); + assert_eq!(1, 31u32.digit_count(32)); + assert_eq!(2, 32u32.digit_count(32)); + assert_eq!(2, 33u32.digit_count(32)); + assert_eq!(2, 1023u32.digit_count(32)); + assert_eq!(3, 1024u32.digit_count(32)); + assert_eq!(3, 1025u32.digit_count(32)); + + if cfg!(feature = "power-of-two") { + for i in 1usize..=7 { + let value = 32u128.pow(i as u32); + assert_eq!(i + 1, value.digit_count(32)); + assert_eq!(i + 1, (value + 1).digit_count(32)); + assert_eq!(i, (value - 1).digit_count(32)); + + let halfway = value + 16u128.pow(i as u32); + assert_eq!(i + 1, halfway.digit_count(32)); + assert_eq!(i + 1, (halfway + 1).digit_count(32)); + assert_eq!(i + 1, (halfway - 1).digit_count(32)); + } + } +} + default_quickcheck! { + fn decimal_count_quickcheck(x: u32) -> bool { + x.digit_count(10) == x.decimal_count() + } + fn fast_log2_quickcheck(x: u32) -> bool { slow_log2(x) == digit_count::fast_log2(x) } } + +macro_rules! ilog { + ($x:ident, $radix:expr) => {{ + if $x > 0 { + $x.ilog($radix as _) as usize + } else { + 0usize + } + }}; +} + +proptest! { + #![proptest_config(default_proptest_config())] + + #[test] + fn basen_u64_test(x: u64, radix in 2u32..=36) { + prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1); + } + + #[test] + #[cfg(feature = "radix")] + fn basen_u128_test(x: u128, radix in 2u32..=36) { + prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1); + } + + #[test] + #[cfg(all(feature = "power-of-two", not(feature = "radix")))] + fn basen_u128_test(x: u128, power in 1u32..=5) { + let radix = 2u32.pow(power); + prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1); + } + + #[test] + #[cfg(not(feature = "power-of-two"))] + fn basen_u128_test(x: u128) { + prop_assert_eq!(x.digit_count(10), ilog!(x, 10) + 1); + } +}