Skip to content

Commit

Permalink
Add lossless TryFrom implementations.
Browse files Browse the repository at this point in the history
This adds `from_f*_lossless` functions, which then are implemented as `TryFrom`, which will only return values if they can fully roundtrip.

Related to starkat99#90
  • Loading branch information
Alexhuszagh committed Dec 15, 2024
1 parent a441b02 commit 0719610
Show file tree
Hide file tree
Showing 11 changed files with 649 additions and 6 deletions.
13 changes: 10 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

## [0.1.1][v0.1.1] - 2024-12-12 <a name="0.1.0"></a>
## [0.1.2][v0.1.2] - 2024-12-12 <a name="0.1.2"></a>

## Added

- Added lossless `from_f*_lossless` functions and `TryFrom` implementations which will never have rounding error.

## [0.1.1][v0.1.1] - 2024-12-12 <a name="0.1.1"></a>

## Changed

Expand Down Expand Up @@ -471,8 +477,9 @@ These were all changes for half, which `float16` is a fork of.
<!-- Versions -->

[Unreleased]: https://github.com/starkat99/half-rs/compare/v2.4.1...HEAD
[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/0.1.1...0.1.0
[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/0.1.0...v2.4.0
[v0.1.2]: https://github.com/Alexhuszagh/float16/compare/v0.1.1...v0.1.2
[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v0.1.1
[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v2.4.0
[2.4.0]: https://github.com/starkat99/half-rs/compare/v2.3.1...v2.4.0
[2.3.1]: https://github.com/starkat99/half-rs/compare/v2.3.0...v2.3.1
[2.3.0]: https://github.com/starkat99/half-rs/compare/v2.2.1...v2.3.0
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "float16"
# Remember to keep in sync with html_root_url crate attribute
version = "0.1.1"
version = "0.1.2"
authors = ["Kathryn Long <[email protected]>", "Alex Huszagh <[email protected]>"]
description = "Half-precision floating point f16 and bf16 types for Rust implementing the IEEE 754-2008 standard binary16 and bfloat16 types."
repository = "https://github.com/Alexhuszagh/float16"
Expand Down
2 changes: 1 addition & 1 deletion devel/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions devel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ authors = ["Alex Huszagh <[email protected]>"]
edition = "2021"
publish = false

[features]
std = ["float16/std"]

[dependencies.float16]
path = ".."
default-features = false
Expand Down
21 changes: 21 additions & 0 deletions devel/tests/bfloat_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,24 @@ fn qc_roundtrip_bf16_f64_is_identity(bits: u16) -> bool {
f.to_bits() == roundtrip.to_bits()
}
}

#[quickcheck]
fn qc_roundtrip_from_f32_lossless(f: f32) -> bool {
if let Some(b) = bf16::from_f32_lossless(f) {
(b.is_nan() == f.is_nan()) || b.as_f32() == f
} else {
true
}
}

#[quickcheck]
fn qc_roundtrip_from_f64_lossless(f: f64) -> bool {
if let Some(b) = bf16::from_f64_lossless(f) {
if !((b.is_nan() && f.is_nan()) || b.as_f64() == f) {
println!("b {}, f {}", b.to_bits(), f.to_bits());
}
(b.is_nan() == f.is_nan()) || b.as_f64() == f
} else {
true
}
}
248 changes: 248 additions & 0 deletions src/bfloat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use core::{
str::FromStr,
};

use crate::error::TryFromFloatError;
use crate::try_from::try_from_lossless;

pub(crate) mod convert;

/// A 16-bit floating point type implementing the [`bfloat16`] format.
Expand Down Expand Up @@ -76,6 +79,61 @@ impl bf16 {
bf16(convert::f32_to_bf16(value))
}

/// Create a [`struct@bf16`] loslessly from an [`f32`].
///
/// This is only true if the [`f32`] is non-finite
/// (infinite or NaN), or no non-zero bits would
/// be truncated.
///
/// "Lossless" does not mean the data is represented the
/// same as a decimal number. For example, an [`f32`]
/// and [`f64`] have the significant digits (excluding the
/// hidden bit) for a value closest to `1e35` of:
/// - `f32`: `110100001001100001100`
/// - `f64`: `11010000100110000110000000000000000000000000000000`
///
/// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
/// while the value closest to `1e35` in [`f64`] is
/// `11010000100110000101110010110001110100110110000010`. This
/// makes it look like precision has been lost but this is
/// due to the approximations used to represent binary values as
/// a decimal.
///
/// This does not respect signalling NaNs: if the value
/// is NaN or inf, then it will return that value.
///
/// Since [`struct@bf16`] has the same number of exponent
/// bits as [`f32`], this is effectively just checking if the
/// value is non-finite (infinite or NaN) or the value
/// is normal and the lower 16 bits are 0.
#[inline]
pub const fn from_f32_lossless(value: f32) -> Option<bf16> {
// NOTE: This logic is effectively just getting the top 16 bits
// and the bottom 16 bits, but it's done explicitly with mantissa
// digits for this reason. For explicit clarity, we remove the
// hidden bit in our exponent logic
const BF16_MANT_BITS: u32 = bf16::MANTISSA_DIGITS - 1;
const F32_MANT_BITS: u32 = f32::MANTISSA_DIGITS - 1;
const EXP_MASK: u32 = (f32::MAX_EXP as u32 * 2 - 1) << F32_MANT_BITS;
const TRUNCATED: u32 = F32_MANT_BITS - BF16_MANT_BITS;
const TRUNC_MASK: u32 = (1 << TRUNCATED) - 1;

// SAFETY: safe since it's plain old data
let bits: u32 = unsafe { core::mem::transmute(value) };

// `bits & exp_mask == exp_mask` -> infinite or NaN
// `truncated == 0` -> no bits truncated
// since the exp ranges are the same, any denormal handling
// is already implicit.
let exp = bits & EXP_MASK;
let is_special = exp == EXP_MASK;
if is_special || bits & TRUNC_MASK == 0 {
Some(Self::from_f32_const(value))
} else {
None
}
}

/// Constructs a [`struct@bf16`] value from a 64-bit floating point value.
///
/// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will
Expand Down Expand Up @@ -107,6 +165,41 @@ impl bf16 {
bf16(convert::f64_to_bf16(value))
}

/// Create a [`struct@bf16`] loslessly from an [`f64`].
///
/// This is only true if the [`f64`] is non-finite
/// (infinite or NaN), zero, or the exponent can be
/// represented by a normal [`struct@bf16`] and no
/// non-zero bits would be truncated.
///
/// "Lossless" does not mean the data is represented the
/// same as a decimal number. For example, an [`f32`]
/// and [`f64`] have the significant digits (excluding the
/// hidden bit) for a value closest to `1e35` of:
/// - `f32`: `110100001001100001100`
/// - `f64`: `11010000100110000110000000000000000000000000000000`
///
/// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
/// while the value closest to `1e35` in [`f64`] is
/// `11010000100110000101110010110001110100110110000010`. This
/// makes it look like precision has been lost but this is
/// due to the approximations used to represent binary values as
/// a decimal.
///
/// This does not respect signalling NaNs: if the value
/// is NaN or inf, then it will return that value.
#[inline]
pub const fn from_f64_lossless(value: f64) -> Option<bf16> {
try_from_lossless!(
value => value,
half => bf16,
full => f64,
half_bits => u16,
full_bits => u64,
to_half => from_f64
)
}

/// Converts a [`struct@bf16`] into the underlying bit representation.
#[inline]
#[must_use]
Expand Down Expand Up @@ -883,6 +976,24 @@ impl From<u8> for bf16 {
}
}

impl TryFrom<f32> for bf16 {
type Error = TryFromFloatError;

#[inline]
fn try_from(x: f32) -> Result<Self, Self::Error> {
Self::from_f32_lossless(x).ok_or(TryFromFloatError(()))
}
}

impl TryFrom<f64> for bf16 {
type Error = TryFromFloatError;

#[inline]
fn try_from(x: f64) -> Result<Self, Self::Error> {
Self::from_f64_lossless(x).ok_or(TryFromFloatError(()))
}
}

impl PartialEq for bf16 {
fn eq(&self, other: &bf16) -> bool {
eq(*self, *other)
Expand Down Expand Up @@ -1730,4 +1841,141 @@ mod test {
assert_eq!(bf16::from_f64(252.50f64).to_bits(), bf16::from_f64(252.0).to_bits());
assert_eq!(bf16::from_f64(252.51f64).to_bits(), bf16::from_f64(253.0).to_bits());
}

#[test]
fn from_f32_lossless() {
let from_f32 = |v: f32| bf16::from_f32_lossless(v);
let roundtrip = |v: f32, expected: Option<bf16>| {
let half = from_f32(v);
assert_eq!(half, expected);
if !expected.is_none() {
let as_f32 = expected.unwrap().to_f32_const();
assert_eq!(v, as_f32);
}
};

assert_eq!(from_f32(f32::NAN).map(bf16::is_nan), Some(true));
roundtrip(f32::INFINITY, Some(bf16::INFINITY));
roundtrip(f32::NEG_INFINITY, Some(bf16::NEG_INFINITY));
roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(bf16(0)));
roundtrip(
f32::from_bits(0b1_00000000_00000000000000000000000),
Some(bf16(bf16::SIGN_MASK)),
);
roundtrip(f32::from_bits(1), None);
roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
roundtrip(f32::from_bits(0b0_00001010_10101011000000000000000), None);
roundtrip(
f32::from_bits(0b0_00001010_10101010000000000000000),
Some(bf16(0b0_00001010_1010101)),
);
roundtrip(f32::from_bits(0b0_00000000_10000000000000000000000), Some(bf16(0x40)));
// special truncation with denormals, etc.
roundtrip(f32::from_bits(0b0_00000000_00000001000000000000000), None);
roundtrip(f32::from_bits(0b0_00000000_00000010000000000000000), Some(bf16(1)));
roundtrip(f32::from_bits(0b0_00000000_00000100000000000000000), Some(bf16(2)));
roundtrip(f32::from_bits(0b0_00000000_00000110000000000000000), Some(bf16(3)));
roundtrip(f32::from_bits(0b0_00000000_00000111000000000000000), None);
roundtrip(f32::from_bits(0b0_00001011_10100111101101101001001), None);
// 1.99170198e-35 and has bits until 16 to the end, so truncated 2
roundtrip(f32::from_bits(0b0_00001011_10100111100000000000000), None);
// 1.99170198e-35 and has bits until 15 to the end, so truncated 1
roundtrip(f32::from_bits(0b0_00001011_10100111000000000000000), None);
// 1.99170198e-35 and has bits until 15 to the end, so truncated 1
roundtrip(f32::from_bits(0b0_00001011_10100110000000000000000), Some(bf16(0x05d3)));
}

#[test]
fn from_f64_lossless() {
let from_f64 = |v: f64| bf16::from_f64_lossless(v);
let roundtrip = |v: f64, expected: Option<bf16>| {
let half = from_f64(v);
assert_eq!(half, expected);
if !expected.is_none() {
let as_f64 = expected.unwrap().to_f64_const();
assert_eq!(v, as_f64);
}
};

assert_eq!(from_f64(f64::NAN).map(bf16::is_nan), Some(true));
roundtrip(f64::INFINITY, Some(bf16::INFINITY));
roundtrip(f64::NEG_INFINITY, Some(bf16::NEG_INFINITY));
roundtrip(
f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000),
Some(bf16(0)),
);
roundtrip(
f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000),
Some(bf16(bf16::SIGN_MASK)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111),
None,
);
// 1.99170198e-35 and has bits until 44 to the end, so truncated 1
roundtrip(
f64::from_bits(0b0_01110001010_1010100100000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100000000000000000000000000000000000000000000000),
Some(bf16(0x0554)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010101000000000000000000000000000000000000000000000),
Some(bf16(0x0555)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010110000000000000000000000000000000000000000000000),
Some(bf16(0x0556)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010111000000000000000000000000000000000000000000000),
Some(bf16(0x0557)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010101100000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100110000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b1_01110001010_1010100000000000000000000000000000000000000000000000),
Some(bf16(0x8554)),
);
roundtrip(
f64::from_bits(0b1_01110001010_1010101000000000000000000000000000000000000000000000),
Some(bf16(0x8555)),
);
// exp out of range but finite
roundtrip(
f64::from_bits(0b1_11110001010_1010101000000000000000000000000000000000000000000000),
None,
);
// explicitly check denormals
roundtrip(
f64::from_bits(0b0_01101111010_0000000000000000000000000000000000000000000000000000),
Some(bf16(1)),
);
roundtrip(
f64::from_bits(0b0_01101111011_1000000000000000000000000000000000000000000000000000),
Some(bf16(3)),
);
roundtrip(
f64::from_bits(0b0_01101111011_1100000000000000000000000000000000000000000000000000),
None,
);
// Due to being denormal, this is truncated out
roundtrip(
f64::from_bits(0b0_01101111010_0001000000000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01101111010_1000000000000000000000000000000000000000000000000000),
None,
);
}
}
Loading

0 comments on commit 0719610

Please sign in to comment.