Skip to content

Commit

Permalink
Remove most unsafety from the Grisu algorithm.
Browse files Browse the repository at this point in the history
This is the compact, slow algorithm, and the change mostly did not impact
performance: results stayed largely within +/-3% of the baseline.
  • Loading branch information
Alexhuszagh committed Sep 14, 2024
1 parent a27501d commit 4bbec3a
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 61 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ Lexical is highly customizable, and contains numerous other optional features:

To ensure safety when bounds checking is disabled, we extensively fuzz all the numeric conversion routines. See the [Safety](#safety) section below for more information.

Lexical also places a heavy focus on code bloat: with algorithms both optimized for performance and size. By default, this focuses on performance, however, using the `compact` feature, you can also opt-in to reduced code size at the cost of performance. The compact algorithms minimize the use of pre-computed tables and other optimizations at the cost of performance.
Lexical also places a heavy focus on code bloat, with algorithms optimized for both performance and size. By default, this focuses on performance; however, using the `compact` feature, you can also opt in to reduced code size at the cost of performance. The compact algorithms minimize the use of pre-computed tables and other optimizations at a major cost to performance.

## Customization

Expand Down
2 changes: 1 addition & 1 deletion lexical-parse-float/src/libm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ macro_rules! i {
if cfg!(feature = "safe") {
$array[$index]
} else {
$array.get_unchecked($index)
*$array.get_unchecked($index)
}
}
};
Expand Down
1 change: 1 addition & 0 deletions lexical-write-float/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use lexical_write_integer::write::WriteInteger;
///
/// Panics if exponent notation is used, and the exponent base and
/// mantissa radix are not the same in `FORMAT`.
// TODO: This needs to be safe
pub unsafe fn write_float<F: Float, const FORMAT: u128>(
float: F,
bytes: &mut [u8],
Expand Down
92 changes: 41 additions & 51 deletions lexical-write-float/src/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ use lexical_util::num::{AsPrimitive, Float};
///
/// Safe as long as the float isn't special (NaN or Infinity), and `bytes`
/// is large enough to hold the significant digits.
pub unsafe fn write_float<F: RawFloat, const FORMAT: u128>(
// TODO: This needs to be safe
pub fn write_float<F: RawFloat, const FORMAT: u128>(
float: F,
bytes: &mut [u8],
options: &Options,
Expand All @@ -70,11 +71,10 @@ pub unsafe fn write_float<F: RawFloat, const FORMAT: u128>(
} else {
// SAFETY: safe since `digits.len()` is large enough to always hold
// the generated digits, which is always <= 18.
unsafe {
let (start, k) = grisu(float, &mut digits);
let (end, carried) = shared::truncate_and_round_decimal(&mut digits, start, options);
(end, k + start as i32 - end as i32, carried)
}
let (start, k) = grisu(float, &mut digits);
let (end, carried) =
unsafe { shared::truncate_and_round_decimal(&mut digits, start, options) };
(end, k + start as i32 - end as i32, carried)
};

let sci_exp = kappa + digit_count as i32 - 1 + carried as i32;
Expand Down Expand Up @@ -118,19 +118,17 @@ pub unsafe fn write_float_scientific<const FORMAT: u128>(

// Write our significant digits
let mut cursor: usize;
bytes[0] = digits[0];
bytes[1] = decimal_point;
unsafe {
// SAFETY: safe since `digits.len() == 32 && bytes.len() >= 2`.
index_unchecked_mut!(bytes[0] = digits[0]);
index_unchecked_mut!(bytes[1]) = decimal_point;

// SAFETY: safe if bytes is large enough to store all significant digits.
if !format.no_exponent_without_fraction() && digit_count == 1 && options.trim_floats() {
// No more digits and need to trim floats.
cursor = 1;
} else if digit_count < exact_count {
// Write our significant digits.
let src = index_unchecked!(digits[1..digit_count]).as_ptr();
let dst = &mut index_unchecked_mut!(bytes[2..digit_count + 1]);
let src = digits[1..digit_count].as_ptr();
let dst = &mut bytes[2..digit_count + 1];
copy_nonoverlapping_unchecked!(dst, src, digit_count - 1);
cursor = digit_count + 1;

Expand All @@ -140,12 +138,12 @@ pub unsafe fn write_float_scientific<const FORMAT: u128>(
cursor += zeros;
} else if digit_count == 1 {
// Write a single, trailing 0.
index_unchecked_mut!(bytes[2]) = b'0';
bytes[2] = b'0';
cursor = 3;
} else {
// Write our significant digits.
let src = index_unchecked!(digits[1..digit_count]).as_ptr();
let dst = &mut index_unchecked_mut!(bytes[2..digit_count + 1]);
let src = digits[1..digit_count].as_ptr();
let dst = &mut bytes[2..digit_count + 1];
copy_nonoverlapping_unchecked!(dst, src, digit_count - 1);
cursor = digit_count + 1;
}
Expand Down Expand Up @@ -184,9 +182,9 @@ pub unsafe fn write_float_negative_exponent<const FORMAT: u128>(
// Write our 0 digits. Note that we cannot have carried, since we previously
// adjusted for carrying and rounding before.
// SAFETY: safe if `bytes.len() < BUFFER_SIZE - 2`.
bytes[0] = b'0';
bytes[1] = decimal_point;
unsafe {
index_unchecked_mut!(bytes[0]) = b'0';
index_unchecked_mut!(bytes[1]) = decimal_point;
let digits = &mut index_unchecked_mut!(bytes[2..sci_exp + 1]);
slice_fill_unchecked!(digits, b'0');
}
Expand All @@ -196,7 +194,7 @@ pub unsafe fn write_float_negative_exponent<const FORMAT: u128>(
// SAFETY: safe if the buffer is large enough to hold all the significant digits.
unsafe {
let src = digits.as_ptr();
let dst = &mut index_unchecked_mut!(bytes[cursor..cursor + digit_count]);
let dst = &mut bytes[cursor..cursor + digit_count];
copy_nonoverlapping_unchecked!(dst, src, digit_count);
cursor += digit_count;
}
Expand All @@ -209,7 +207,7 @@ pub unsafe fn write_float_negative_exponent<const FORMAT: u128>(
let zeros = exact_count - digit_count;
// SAFETY: safe if bytes is large enough to hold the significant digits.
unsafe {
slice_fill_unchecked!(index_unchecked_mut!(bytes[cursor..cursor + zeros]), b'0');
slice_fill_unchecked!(bytes[cursor..cursor + zeros], b'0');
}
cursor += zeros;
}
Expand Down Expand Up @@ -248,19 +246,19 @@ pub unsafe fn write_float_positive_exponent<const FORMAT: u128>(
// SAFETY: safe if the buffer is large enough to hold the significant digits.
unsafe {
let src = digits.as_ptr();
let dst = &mut index_unchecked_mut!(bytes[..digit_count]);
let dst = &mut bytes[..digit_count];
copy_nonoverlapping_unchecked!(dst, src, digit_count);
let digits = &mut index_unchecked_mut!(bytes[digit_count..leading_digits]);
let digits = &mut bytes[digit_count..leading_digits];
slice_fill_unchecked!(digits, b'0');
}
cursor = leading_digits;
digit_count = leading_digits;
// Only write decimal point if we're not trimming floats.
if !options.trim_floats() {
// SAFETY: safe if `cursor + 2 <= bytes.len()`.
unsafe { index_unchecked_mut!(bytes[cursor]) = decimal_point };
bytes[cursor] = decimal_point;
cursor += 1;
unsafe { index_unchecked_mut!(bytes[cursor]) = b'0' };
bytes[cursor] = b'0';
cursor += 1;
digit_count += 1;
} else {
Expand All @@ -273,16 +271,16 @@ pub unsafe fn write_float_positive_exponent<const FORMAT: u128>(
// SAFETY: safe if the buffer is large enough to hold the significant digits.
unsafe {
let src = digits.as_ptr();
let dst = &mut index_unchecked_mut!(bytes[..leading_digits]);
let dst = &mut bytes[..leading_digits];
copy_nonoverlapping_unchecked!(dst, src, leading_digits);
index_unchecked_mut!(bytes[leading_digits]) = decimal_point;
bytes[leading_digits] = decimal_point;
}

// Write the digits after the decimal point.
// SAFETY: safe if the buffer is large enough to hold the significant digits.
unsafe {
let src = index_unchecked!(digits[leading_digits..digit_count]).as_ptr();
let dst = &mut index_unchecked_mut!(bytes[leading_digits + 1..digit_count + 1]);
let src = digits[leading_digits..digit_count].as_ptr();
let dst = &mut bytes[leading_digits + 1..digit_count + 1];
copy_nonoverlapping_unchecked!(dst, src, digit_count - leading_digits);
}

Expand All @@ -298,7 +296,7 @@ pub unsafe fn write_float_positive_exponent<const FORMAT: u128>(
let zeros = exact_count - digit_count;
// SAFETY: safe if the buffer is large enough to hold the significant digits.
unsafe {
let digits = &mut index_unchecked_mut!(bytes[cursor..cursor + zeros]);
let digits = &mut bytes[cursor..cursor + zeros];
slice_fill_unchecked!(digits, b'0');
}
cursor += zeros;
Expand All @@ -311,11 +309,7 @@ pub unsafe fn write_float_positive_exponent<const FORMAT: u128>(
// ---------

/// Round digit to normal approximation.
///
/// # Safety
///
/// Safe as long as `digit_count <= digits.len() && digit_count > 0`.
unsafe fn round_digit(
fn round_digit(
digits: &mut [u8],
digit_count: usize,
delta: u64,
Expand All @@ -329,8 +323,7 @@ unsafe fn round_digit(
&& delta - rem >= kappa
&& (rem + kappa < mant || mant - rem > rem + kappa - mant)
{
// SAFETY: safe if `digit_count > 0`.
unsafe { index_unchecked_mut!(digits[digit_count - 1]) -= 1 };
digits[digit_count - 1] -= 1;
rem += kappa;
}
}
Expand All @@ -340,7 +333,7 @@ unsafe fn round_digit(
/// # Safety
///
/// Safe as long as the extended float does not represent a 0.
pub unsafe fn generate_digits(
pub fn generate_digits(
fp: &ExtendedFloat80,
upper: &ExtendedFloat80,
lower: &ExtendedFloat80,
Expand All @@ -367,8 +360,7 @@ pub unsafe fn generate_digits(
while kappa > 0 {
let digit = part1 / div;
if digit != 0 || idx != 0 {
// SAFETY: safe, digits.len() == 32.
unsafe { index_unchecked_mut!(digits[idx]) = digit_to_char_const(digit as u32, 10) };
digits[idx] = digit_to_char_const(digit as u32, 10);
idx += 1;
}

Expand All @@ -378,8 +370,7 @@ pub unsafe fn generate_digits(
let tmp = (part1 << -one.exp) + part2;
if tmp <= delta {
k += kappa;
// SAFETY: safe since `idx > 0 && idx < digits.len()`.
unsafe { round_digit(digits, idx, delta, tmp, div << -one.exp, wmant) };
round_digit(digits, idx, delta, tmp, div << -one.exp, wmant);
return (idx, k);
}
div /= 10;
Expand All @@ -399,15 +390,14 @@ pub unsafe fn generate_digits(
// In practice, this can't exceed 18, however, we have extra digits
// **just** in case, since we write technically up to 29 here
// before we underflow TENS.
unsafe { index_unchecked_mut!(digits[idx]) = digit_to_char_const(digit as u32, 10) };
digits[idx] = digit_to_char_const(digit as u32, 10);
idx += 1;
}

part2 &= one.mant - 1;
if part2 < delta {
k += kappa;
// SAFETY: safe since `idx < digits.len() && idx > 0`.
unsafe { round_digit(digits, idx, delta, part2, one.mant, wmant * ten) };
round_digit(digits, idx, delta, part2, one.mant, wmant * ten);
return (idx, k);
}
ten *= 10;
Expand All @@ -423,15 +413,15 @@ pub unsafe fn generate_digits(
/// # Safety
///
/// Safe as long as float is not 0.
pub unsafe fn grisu<F: Float>(float: F, digits: &mut [u8]) -> (usize, i32) {
pub fn grisu<F: Float>(float: F, digits: &mut [u8]) -> (usize, i32) {
debug_assert!(float != F::ZERO);

let mut w = from_float(float);

let (lower, upper) = normalized_boundaries::<F>(&w);
normalize(&mut w);
// SAFETY: safe since upper.exp must be in the valid binary range.
let (cp, ki) = unsafe { cached_grisu_power(upper.exp) };
let (cp, ki) = cached_grisu_power(upper.exp);

let w = mul(&w, &cp);
let mut upper = mul(&upper, &cp);
Expand All @@ -443,7 +433,7 @@ pub unsafe fn grisu<F: Float>(float: F, digits: &mut [u8]) -> (usize, i32) {
let k = -ki;

// SAFETY: safe since generate_digits can only generate 18 digits
unsafe { generate_digits(&w, &upper, &lower, digits, k) }
generate_digits(&w, &upper, &lower, digits, k)
}

// EXTENDED FLOAT
Expand Down Expand Up @@ -551,7 +541,7 @@ pub fn mul(x: &ExtendedFloat80, y: &ExtendedFloat80) -> ExtendedFloat80 {
/// # Safety
///
/// Safe as long as exp is within the range [-1140, 1089]
unsafe fn cached_grisu_power(exp: i32) -> (ExtendedFloat80, i32) {
fn cached_grisu_power(exp: i32) -> (ExtendedFloat80, i32) {
// Make the bounds 64 + 1 larger, since those will still work,
// but the exp can be biased within that range.
debug_assert!(((-1075 - 64 - 1)..=(1024 + 64 + 1)).contains(&exp));
Expand All @@ -570,7 +560,7 @@ unsafe fn cached_grisu_power(exp: i32) -> (ExtendedFloat80, i32) {

loop {
// SAFETY: safe as long as the original exponent was in range.
let mant = unsafe { f64::grisu_power(idx) };
let mant = f64::grisu_power(idx);
let decexp = fast_decimal_power(idx);
let binexp = fast_binary_power(decexp);
let current = exp + binexp + 64;
Expand Down Expand Up @@ -617,9 +607,9 @@ pub trait GrisuFloat: Float {
///
/// Safe as long as `index < GRISU_POWERS_OF_TEN.len()`.
#[inline(always)]
unsafe fn grisu_power(index: usize) -> u64 {
fn grisu_power(index: usize) -> u64 {
debug_assert!(index <= GRISU_POWERS_OF_TEN.len());
unsafe { index_unchecked!(GRISU_POWERS_OF_TEN[index]) }
GRISU_POWERS_OF_TEN[index]
}
}

Expand All @@ -636,7 +626,7 @@ macro_rules! grisu_unimpl {
($($t:ident)*) => ($(
impl GrisuFloat for $t {
#[inline(always)]
unsafe fn grisu_power(_: usize) -> u64 {
fn grisu_power(_: usize) -> u64 {
unimplemented!()
}
}
Expand Down
16 changes: 8 additions & 8 deletions lexical-write-float/tests/compact_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ fn mul_test() {

fn grisu<T: RawFloat>(f: T, expected: &str, k: i32) {
let mut buffer = [b'\x00'; 32];
let (count, real_k) = unsafe { compact::grisu(f, &mut buffer) };
let (count, real_k) = compact::grisu(f, &mut buffer);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
assert_eq!(actual, expected);
assert_eq!(k, real_k);
Expand All @@ -356,7 +356,7 @@ fn grisu_test() {

fn write_float<T: RawFloat, const FORMAT: u128>(f: T, options: &Options, expected: &str) {
let mut buffer = [b'\x00'; BUFFER_SIZE];
let count = unsafe { compact::write_float::<_, FORMAT>(f, &mut buffer, options) };
let count = compact::write_float::<_, FORMAT>(f, &mut buffer, options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
assert_eq!(actual, expected);
}
Expand Down Expand Up @@ -640,7 +640,7 @@ fn f32_roundtrip_test() {
let mut buffer = [b'\x00'; BUFFER_SIZE];
let options = Options::builder().build().unwrap();
for &float in F32_DATA.iter() {
let count = unsafe { compact::write_float::<_, DECIMAL>(float, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(float, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f32>();
assert_eq!(roundtrip, Ok(float));
Expand Down Expand Up @@ -754,7 +754,7 @@ fn f64_roundtrip_test() {
let mut buffer = [b'\x00'; BUFFER_SIZE];
let options = Options::builder().build().unwrap();
for &float in F64_DATA.iter() {
let count = unsafe { compact::write_float::<_, DECIMAL>(float, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(float, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f64>();
assert_eq!(roundtrip, Ok(float));
Expand All @@ -769,7 +769,7 @@ default_quickcheck! {
if f.is_special() {
true
} else {
let count = unsafe { compact::write_float::<_, DECIMAL>(f, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(f, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f32>();
roundtrip == Ok(f)
Expand All @@ -783,7 +783,7 @@ default_quickcheck! {
if f.is_special() {
true
} else {
let count = unsafe { compact::write_float::<_, DECIMAL>(f, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(f, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f64>();
roundtrip == Ok(f)
Expand All @@ -800,7 +800,7 @@ proptest! {
let options = Options::builder().build().unwrap();
let f = f.abs();
if !f.is_special() {
let count = unsafe { compact::write_float::<_, DECIMAL>(f, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(f, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f32>();
prop_assert_eq!(roundtrip, Ok(f))
Expand All @@ -813,7 +813,7 @@ proptest! {
let options = Options::builder().build().unwrap();
let f = f.abs();
if !f.is_special() {
let count = unsafe { compact::write_float::<_, DECIMAL>(f, &mut buffer, &options) };
let count = compact::write_float::<_, DECIMAL>(f, &mut buffer, &options);
let actual = unsafe { std::str::from_utf8_unchecked(&buffer[..count]) };
let roundtrip = actual.parse::<f64>();
prop_assert_eq!(roundtrip, Ok(f))
Expand Down

0 comments on commit 4bbec3a

Please sign in to comment.