Skip to content

Commit

Permalink
Add optimizations for digit counting.
Browse files Browse the repository at this point in the history
This adds optimizations for base 2 and base 4 radices, which can use fast log2 implementations for digit counting.
  • Loading branch information
Alexhuszagh committed Dec 5, 2024
1 parent 21363ef commit 8421974
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 13 deletions.
114 changes: 102 additions & 12 deletions lexical-write-integer/src/digit_count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#![doc(hidden)]

use lexical_util::{
assert::debug_assert_radix,
div128::u128_divrem,
num::{AsPrimitive, UnsignedInteger},
step::u64_step,
Expand All @@ -31,9 +32,44 @@ pub fn fast_log2<T: UnsignedInteger>(x: T) -> usize {
T::BITS - 1 - (x | T::ONE).leading_zeros() as usize
}

// Uses a naive approach to calculate the number of digits.
macro_rules! naive_count {
($t:ty, $radix:expr, $x:expr) => {{
// Algorithms to calculate the number of digits from a single value.
macro_rules! digit_count {
// Highly-optimized digit count for 2^N values.
(@2 $x:expr) => {{
digit_log2($x)
}};

(@4 $x:expr) => {{
digit_log4($x)
}};

(@8 $x:expr) => {{
digit_log8($x)
}};

(@16 $x:expr) => {{
digit_log16($x)
}};

(@32 $x:expr) => {{
digit_log32($x)
}};

// Uses a naive approach to calculate the number of digits.
// This uses multi-digit optimizations when possible, and always
// accurately calculates the number of digits similar to how
// the digit generation algorithm works, just without the table
// lookups.
//
// There's no good way to do this, since float logn functions are
// lossy and the value might not be exactly represented in the type
// (that is, `>= 2^53`), in which case `log(b^x, b)`, `log(b^x + 1, b)`,
// and `log(b^x - 1, b)` would all be the same. Rust's integral [`ilog`]
// functions just use naive 1-digit at a time multiplication, so it's
// less efficient than our optimized variant.
//
// [`ilog`]: https://github.com/rust-lang/rust/blob/0e98766/library/core/src/num/uint_macros.rs#L1290-L1320
(@naive $t:ty, $radix:expr, $x:expr) => {{
// If we can do multi-digit optimizations, it's ideal,
// so we want to check if our type size max value is >=
// to the value.
Expand Down Expand Up @@ -74,6 +110,48 @@ macro_rules! naive_count {
}};
}

/// Highly-optimized digit count for base2 values.
///
/// This is always the number of `BITS - ctlz(x | 1)`, so it's
/// `fast_log2(x) + 1`. This is because 0 has 1 digit, as does 1,
/// but 2 and 3 have 2, etc.
#[inline(always)]
fn digit_log2<T: UnsignedInteger>(x: T) -> usize {
fast_log2(x) + 1
}

/// Highly-optimized digit count for base4 values.
///
/// This is very similar to base 2, except we shift right by
/// 1 and adjust by 1. For example, `fast_log2(3) == 1`, so
/// `fast_log2(3) >> 1 == 0`, which then gives us our result.
#[inline(always)]
fn digit_log4<T: UnsignedInteger>(x: T) -> usize {
// NOTE: This cannot be `fast_log2(fast_log2())`
(fast_log2(x | T::ONE) >> 1) + 1
}

/// Specialized digit count for base8 values.
#[inline(always)]
fn digit_log8<T: UnsignedInteger>(x: T) -> usize {
// FIXME: Optimize
digit_count!(@naive T, 8, x)
}

/// Specialized digit count for base16 values.
#[inline(always)]
fn digit_log16<T: UnsignedInteger>(x: T) -> usize {
// FIXME: Optimize
digit_count!(@naive T, 16, x)
}

/// Specialized digit count for base32 values.
#[inline(always)]
fn digit_log32<T: UnsignedInteger>(x: T) -> usize {
// FIXME: Optimize
digit_count!(@naive T, 32, x)
}

/// Quickly calculate the number of digits in a type.
///
/// This uses optimizations for powers-of-two and decimal
Expand All @@ -93,11 +171,16 @@ pub unsafe trait DigitCount: UnsignedInteger + DecimalCount {
fn digit_count(self, radix: u32) -> usize {
assert!((2..=36).contains(&radix), "radix must be >= 2 and <= 36");
match radix {
// decimal
10 => self.decimal_count(),
// NOTE: This is currently horribly inefficient and exists just for correctness.
// FIXME: Optimize for power-of-two radices
// FIXME: Optimize for non-power-of-two radices.
_ => naive_count!(Self, radix, self),
// 2^N
2 => digit_count!(@2 self),
4 => digit_count!(@4 self),
8 => digit_count!(@8 self),
16 => digit_count!(@16 self),
32 => digit_count!(@32 self),
// fallback
_ => digit_count!(@naive Self, radix, self),
}
}
}
Expand All @@ -119,29 +202,36 @@ unsafe impl DigitCount for u128 {
/// Get the number of digits in a value.
#[inline(always)]
fn digit_count(self, radix: u32) -> usize {
assert!((2..=36).contains(&radix), "radix must be >= 2 and <= 36");
debug_assert_radix(radix);
match radix {
// decimal
10 => self.decimal_count(),
// FIXME: Optimize this
// 2^N
2 => digit_count!(@2 self),
4 => digit_count!(@4 self),
8 => digit_count!(@8 self),
16 => digit_count!(@16 self),
32 => digit_count!(@32 self),
// fallback
_ => {
// NOTE: This follows the same implementation as the digit count
// generation, so this is safe.
if self <= u64::MAX as u128 {
return naive_count!(u64, radix, self as u64);
return digit_count!(@naive u64, radix, self as u64);
}

// Doesn't fit in 64 bits, let's try our divmod.
let step = u64_step(radix);
let (value, _) = u128_divrem(self, radix);
let mut count = step;
if value <= u64::MAX as u128 {
count += naive_count!(u64, radix, value as u64);
count += digit_count!(@naive u64, radix, value as u64);
} else {
// Value has to be greater than 1.8e38
let (value, _) = u128_divrem(value, radix);
count += step;
if value != 0 {
count += naive_count!(u64, radix, value as u64);
count += digit_count!(@naive u64, radix, value as u64);
}
}

Expand Down
181 changes: 180 additions & 1 deletion lexical-write-integer/tests/digit_count_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

mod util;

use lexical_write_integer::digit_count;
use lexical_write_integer::decimal::DecimalCount;
use lexical_write_integer::digit_count::{self, DigitCount};
use proptest::prelude::*;

use crate::util::default_proptest_config;

#[test]
fn fast_log2_test() {
Expand All @@ -28,8 +32,183 @@ fn slow_log2(x: u32) -> usize {
}
}

#[test]
fn base10_count_test() {
assert_eq!(1, 0u32.digit_count(10));
assert_eq!(1, 9u32.digit_count(10));
assert_eq!(2, 10u32.digit_count(10));
assert_eq!(2, 11u32.digit_count(10));
assert_eq!(2, 99u32.digit_count(10));
assert_eq!(3, 100u32.digit_count(10));
assert_eq!(3, 101u32.digit_count(10));
}

#[test]
fn base2_count_test() {
assert_eq!(1, 0u32.digit_count(2));
assert_eq!(1, 1u32.digit_count(2));
assert_eq!(2, 2u32.digit_count(2));
assert_eq!(2, 3u32.digit_count(2));
assert_eq!(3, 4u32.digit_count(2));

if cfg!(feature = "power-of-two") {
for i in 1usize..=127 {
let value = 2u128.pow(i as u32);
assert_eq!(i + 1, value.digit_count(2));
assert_eq!(i + 1, (value + 1).digit_count(2));
assert_eq!(i, (value - 1).digit_count(2));
}
}
}

#[test]
fn base4_count_test() {
assert_eq!(1, 0u32.digit_count(4));
assert_eq!(1, 1u32.digit_count(4));
assert_eq!(1, 3u32.digit_count(4));
assert_eq!(2, 4u32.digit_count(4));
assert_eq!(2, 5u32.digit_count(4));
assert_eq!(2, 15u32.digit_count(4));
assert_eq!(3, 16u32.digit_count(4));
assert_eq!(3, 17u32.digit_count(4));

if cfg!(feature = "power-of-two") {
for i in 1usize..=63 {
let value = 4u128.pow(i as u32);
assert_eq!(i + 1, value.digit_count(4));
assert_eq!(i + 1, (value + 1).digit_count(4));
assert_eq!(i, (value - 1).digit_count(4));

let halfway = value + 2u128.pow(i as u32);
assert_eq!(i + 1, halfway.digit_count(4));
assert_eq!(i + 1, (halfway + 1).digit_count(4));
assert_eq!(i + 1, (halfway - 1).digit_count(4));
}
}
}

#[test]
fn base8_count_test() {
assert_eq!(1, 0u32.digit_count(8));
assert_eq!(1, 1u32.digit_count(8));
assert_eq!(1, 7u32.digit_count(8));
assert_eq!(2, 8u32.digit_count(8));
assert_eq!(2, 9u32.digit_count(8));
assert_eq!(2, 63u32.digit_count(8));
assert_eq!(3, 64u32.digit_count(8));
assert_eq!(3, 65u32.digit_count(8));

if cfg!(feature = "power-of-two") {
for i in 1usize..=31 {
let value = 8u128.pow(i as u32);
assert_eq!(i + 1, value.digit_count(8));
assert_eq!(i + 1, (value + 1).digit_count(8));
assert_eq!(i, (value - 1).digit_count(8));

let halfway = value + 4u128.pow(i as u32);
assert_eq!(i + 1, halfway.digit_count(8));
assert_eq!(i + 1, (halfway + 1).digit_count(8));
assert_eq!(i + 1, (halfway - 1).digit_count(8));
}
}
}

#[test]
fn base16_count_test() {
assert_eq!(1, 0u32.digit_count(16));
assert_eq!(1, 1u32.digit_count(16));
assert_eq!(1, 15u32.digit_count(16));
assert_eq!(2, 16u32.digit_count(16));
assert_eq!(2, 17u32.digit_count(16));
assert_eq!(2, 255u32.digit_count(16));
assert_eq!(3, 256u32.digit_count(16));
assert_eq!(3, 257u32.digit_count(16));

if cfg!(feature = "power-of-two") {
for i in 1usize..=15 {
let value = 16u128.pow(i as u32);
assert_eq!(i + 1, value.digit_count(16));
assert_eq!(i + 1, (value + 1).digit_count(16));
assert_eq!(i, (value - 1).digit_count(16));

let halfway = value + 8u128.pow(i as u32);
assert_eq!(i + 1, halfway.digit_count(16));
assert_eq!(i + 1, (halfway + 1).digit_count(16));
assert_eq!(i + 1, (halfway - 1).digit_count(16));
}
}
}

#[test]
fn base32_count_test() {
assert_eq!(1, 0u32.digit_count(32));
assert_eq!(1, 1u32.digit_count(32));
assert_eq!(1, 31u32.digit_count(32));
assert_eq!(2, 32u32.digit_count(32));
assert_eq!(2, 33u32.digit_count(32));
assert_eq!(2, 1023u32.digit_count(32));
assert_eq!(3, 1024u32.digit_count(32));
assert_eq!(3, 1025u32.digit_count(32));

if cfg!(feature = "power-of-two") {
for i in 1usize..=7 {
let value = 32u128.pow(i as u32);
assert_eq!(i + 1, value.digit_count(32));
assert_eq!(i + 1, (value + 1).digit_count(32));
assert_eq!(i, (value - 1).digit_count(32));

let halfway = value + 16u128.pow(i as u32);
assert_eq!(i + 1, halfway.digit_count(32));
assert_eq!(i + 1, (halfway + 1).digit_count(32));
assert_eq!(i + 1, (halfway - 1).digit_count(32));
}
}
}

default_quickcheck! {
fn decimal_count_quickcheck(x: u32) -> bool {
x.digit_count(10) == x.decimal_count()
}

fn fast_log2_quickcheck(x: u32) -> bool {
slow_log2(x) == digit_count::fast_log2(x)
}
}

macro_rules! ilog {
($x:ident, $radix:expr) => {{
if $x > 0 {
$x.ilog($radix as _) as usize
} else {
0usize
}
}};
}

proptest! {
#![proptest_config(default_proptest_config())]

#[test]
fn basen_u64_test(x: u64, radix in 2u32..=36) {
prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1);
}

#[test]
#[cfg(feature = "radix")]
fn basen_u128_test(x: u128, radix in 2u32..=36) {
prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1);
}

#[test]
#[cfg(all(feature = "power-of-two", not(feature = "radix")))]
fn basen_u128_test(x: u128, power in 1u32..=5) {
let radix = 2u32.pow(power);
prop_assert_eq!(x.digit_count(radix), ilog!(x, radix) + 1);
}

#[test]
#[cfg(not(feature = "power-of-two"))]
fn basen_u128_test(x: u128) {
prop_assert_eq!(x.digit_count(10), ilog!(x, 10) + 1);
}
}

0 comments on commit 8421974

Please sign in to comment.