Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lossless TryFrom implementations. #2

Merged
merged 1 commit into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

## [0.1.1][v0.1.1] - 2024-12-12 <a name="0.1.0"></a>
## [0.1.2][v0.1.2] - 2024-12-12 <a name="0.1.2"></a>

## Added

- Added lossless `from_f*_lossless` functions and `TryFrom` implementations which will never have rounding error.

## [0.1.1][v0.1.1] - 2024-12-12 <a name="0.1.1"></a>

## Changed

Expand Down Expand Up @@ -471,8 +477,9 @@ These were all changes for half, which `float16` is a fork of.
<!-- Versions -->

[Unreleased]: https://github.com/starkat99/half-rs/compare/v2.4.1...HEAD
[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/0.1.1...0.1.0
[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/0.1.0...v2.4.0
[v0.1.2]: https://github.com/Alexhuszagh/float16/compare/v0.1.1...v0.1.2
[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v0.1.1
[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v2.4.0
[2.4.0]: https://github.com/starkat99/half-rs/compare/v2.3.1...v2.4.0
[2.3.1]: https://github.com/starkat99/half-rs/compare/v2.3.0...v2.3.1
[2.3.0]: https://github.com/starkat99/half-rs/compare/v2.2.1...v2.3.0
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "float16"
# Remember to keep in sync with html_root_url crate attribute
version = "0.1.1"
version = "0.1.2"
authors = ["Kathryn Long <[email protected]>", "Alex Huszagh <[email protected]>"]
description = "Half-precision floating point f16 and bf16 types for Rust implementing the IEEE 754-2008 standard binary16 and bfloat16 types."
repository = "https://github.com/Alexhuszagh/float16"
Expand Down
2 changes: 1 addition & 1 deletion devel/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions devel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ authors = ["Alex Huszagh <[email protected]>"]
edition = "2021"
publish = false

[features]
std = ["float16/std"]

[dependencies.float16]
path = ".."
default-features = false
Expand Down
21 changes: 21 additions & 0 deletions devel/tests/bfloat_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,24 @@ fn qc_roundtrip_bf16_f64_is_identity(bits: u16) -> bool {
f.to_bits() == roundtrip.to_bits()
}
}

#[quickcheck]
fn qc_roundtrip_from_f32_lossless(f: f32) -> bool {
if let Some(b) = bf16::from_f32_lossless(f) {
(b.is_nan() == f.is_nan()) || b.as_f32() == f
} else {
true
}
}

#[quickcheck]
fn qc_roundtrip_from_f64_lossless(f: f64) -> bool {
if let Some(b) = bf16::from_f64_lossless(f) {
if !((b.is_nan() && f.is_nan()) || b.as_f64() == f) {
println!("b {}, f {}", b.to_bits(), f.to_bits());
}
(b.is_nan() == f.is_nan()) || b.as_f64() == f
} else {
true
}
}
248 changes: 248 additions & 0 deletions src/bfloat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ use core::{
str::FromStr,
};

use crate::error::TryFromFloatError;
use crate::try_from::try_from_lossless;

pub(crate) mod convert;

/// A 16-bit floating point type implementing the [`bfloat16`] format.
Expand Down Expand Up @@ -76,6 +79,61 @@ impl bf16 {
bf16(convert::f32_to_bf16(value))
}

/// Create a [`struct@bf16`] loslessly from an [`f32`].
///
/// This is only true if the [`f32`] is non-finite
/// (infinite or NaN), or no non-zero bits would
/// be truncated.
///
/// "Lossless" does not mean the data is represented the
/// same as a decimal number. For example, an [`f32`]
/// and [`f64`] have the significant digits (excluding the
/// hidden bit) for a value closest to `1e35` of:
/// - `f32`: `110100001001100001100`
/// - `f64`: `11010000100110000110000000000000000000000000000000`
///
/// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
/// while the value closest to `1e35` in [`f64`] is
/// `11010000100110000101110010110001110100110110000010`. This
/// makes it look like precision has been lost but this is
/// due to the approximations used to represent binary values as
/// a decimal.
///
/// This does not respect signalling NaNs: if the value
/// is NaN or inf, then it will return that value.
///
/// Since [`struct@bf16`] has the same number of exponent
/// bits as [`f32`], this is effectively just checking if the
/// value is non-finite (infinite or NaN) or the value
/// is normal and the lower 16 bits are 0.
#[inline]
pub const fn from_f32_lossless(value: f32) -> Option<bf16> {
// NOTE: This logic is effectively just getting the top 16 bits
// and the bottom 16 bits, but it's done explicitly with mantissa
// digits for this reason. For explicit clarity, we remove the
// hidden bit in our exponent logic
const BF16_MANT_BITS: u32 = bf16::MANTISSA_DIGITS - 1;
const F32_MANT_BITS: u32 = f32::MANTISSA_DIGITS - 1;
const EXP_MASK: u32 = (f32::MAX_EXP as u32 * 2 - 1) << F32_MANT_BITS;
const TRUNCATED: u32 = F32_MANT_BITS - BF16_MANT_BITS;
const TRUNC_MASK: u32 = (1 << TRUNCATED) - 1;

// SAFETY: safe since it's plain old data
let bits: u32 = unsafe { core::mem::transmute(value) };

// `bits & exp_mask == exp_mask` -> infinite or NaN
// `truncated == 0` -> no bits truncated
// since the exp ranges are the same, any denormal handling
// is already implicit.
let exp = bits & EXP_MASK;
let is_special = exp == EXP_MASK;
if is_special || bits & TRUNC_MASK == 0 {
Some(Self::from_f32_const(value))
} else {
None
}
}

/// Constructs a [`struct@bf16`] value from a 64-bit floating point value.
///
/// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will
Expand Down Expand Up @@ -107,6 +165,41 @@ impl bf16 {
bf16(convert::f64_to_bf16(value))
}

/// Create a [`struct@bf16`] loslessly from an [`f64`].
///
/// This is only true if the [`f64`] is non-finite
/// (infinite or NaN), zero, or the exponent can be
/// represented by a normal [`struct@bf16`] and no
/// non-zero bits would be truncated.
///
/// "Lossless" does not mean the data is represented the
/// same as a decimal number. For example, an [`f32`]
/// and [`f64`] have the significant digits (excluding the
/// hidden bit) for a value closest to `1e35` of:
/// - `f32`: `110100001001100001100`
/// - `f64`: `11010000100110000110000000000000000000000000000000`
///
/// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
/// while the value closest to `1e35` in [`f64`] is
/// `11010000100110000101110010110001110100110110000010`. This
/// makes it look like precision has been lost but this is
/// due to the approximations used to represent binary values as
/// a decimal.
///
/// This does not respect signalling NaNs: if the value
/// is NaN or inf, then it will return that value.
#[inline]
pub const fn from_f64_lossless(value: f64) -> Option<bf16> {
try_from_lossless!(
value => value,
half => bf16,
full => f64,
half_bits => u16,
full_bits => u64,
to_half => from_f64
)
}

/// Converts a [`struct@bf16`] into the underlying bit representation.
#[inline]
#[must_use]
Expand Down Expand Up @@ -883,6 +976,24 @@ impl From<u8> for bf16 {
}
}

impl TryFrom<f32> for bf16 {
type Error = TryFromFloatError;

#[inline]
fn try_from(x: f32) -> Result<Self, Self::Error> {
Self::from_f32_lossless(x).ok_or(TryFromFloatError(()))
}
}

impl TryFrom<f64> for bf16 {
type Error = TryFromFloatError;

#[inline]
fn try_from(x: f64) -> Result<Self, Self::Error> {
Self::from_f64_lossless(x).ok_or(TryFromFloatError(()))
}
}

impl PartialEq for bf16 {
fn eq(&self, other: &bf16) -> bool {
eq(*self, *other)
Expand Down Expand Up @@ -1730,4 +1841,141 @@ mod test {
assert_eq!(bf16::from_f64(252.50f64).to_bits(), bf16::from_f64(252.0).to_bits());
assert_eq!(bf16::from_f64(252.51f64).to_bits(), bf16::from_f64(253.0).to_bits());
}

#[test]
fn from_f32_lossless() {
let from_f32 = |v: f32| bf16::from_f32_lossless(v);
let roundtrip = |v: f32, expected: Option<bf16>| {
let half = from_f32(v);
assert_eq!(half, expected);
if !expected.is_none() {
let as_f32 = expected.unwrap().to_f32_const();
assert_eq!(v, as_f32);
}
};

assert_eq!(from_f32(f32::NAN).map(bf16::is_nan), Some(true));
roundtrip(f32::INFINITY, Some(bf16::INFINITY));
roundtrip(f32::NEG_INFINITY, Some(bf16::NEG_INFINITY));
roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(bf16(0)));
roundtrip(
f32::from_bits(0b1_00000000_00000000000000000000000),
Some(bf16(bf16::SIGN_MASK)),
);
roundtrip(f32::from_bits(1), None);
roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
roundtrip(f32::from_bits(0b0_00001010_10101011000000000000000), None);
roundtrip(
f32::from_bits(0b0_00001010_10101010000000000000000),
Some(bf16(0b0_00001010_1010101)),
);
roundtrip(f32::from_bits(0b0_00000000_10000000000000000000000), Some(bf16(0x40)));
// special truncation with denormals, etc.
roundtrip(f32::from_bits(0b0_00000000_00000001000000000000000), None);
roundtrip(f32::from_bits(0b0_00000000_00000010000000000000000), Some(bf16(1)));
roundtrip(f32::from_bits(0b0_00000000_00000100000000000000000), Some(bf16(2)));
roundtrip(f32::from_bits(0b0_00000000_00000110000000000000000), Some(bf16(3)));
roundtrip(f32::from_bits(0b0_00000000_00000111000000000000000), None);
roundtrip(f32::from_bits(0b0_00001011_10100111101101101001001), None);
// 1.99170198e-35 and has bits until 16 to the end, so truncated 2
roundtrip(f32::from_bits(0b0_00001011_10100111100000000000000), None);
// 1.99170198e-35 and has bits until 15 to the end, so truncated 1
roundtrip(f32::from_bits(0b0_00001011_10100111000000000000000), None);
// 1.99170198e-35 and has bits until 15 to the end, so truncated 1
roundtrip(f32::from_bits(0b0_00001011_10100110000000000000000), Some(bf16(0x05d3)));
}

#[test]
fn from_f64_lossless() {
let from_f64 = |v: f64| bf16::from_f64_lossless(v);
let roundtrip = |v: f64, expected: Option<bf16>| {
let half = from_f64(v);
assert_eq!(half, expected);
if !expected.is_none() {
let as_f64 = expected.unwrap().to_f64_const();
assert_eq!(v, as_f64);
}
};

assert_eq!(from_f64(f64::NAN).map(bf16::is_nan), Some(true));
roundtrip(f64::INFINITY, Some(bf16::INFINITY));
roundtrip(f64::NEG_INFINITY, Some(bf16::NEG_INFINITY));
roundtrip(
f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000),
Some(bf16(0)),
);
roundtrip(
f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000),
Some(bf16(bf16::SIGN_MASK)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111),
None,
);
// 1.99170198e-35 and has bits until 44 to the end, so truncated 1
roundtrip(
f64::from_bits(0b0_01110001010_1010100100000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100000000000000000000000000000000000000000000000),
Some(bf16(0x0554)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010101000000000000000000000000000000000000000000000),
Some(bf16(0x0555)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010110000000000000000000000000000000000000000000000),
Some(bf16(0x0556)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010111000000000000000000000000000000000000000000000),
Some(bf16(0x0557)),
);
roundtrip(
f64::from_bits(0b0_01110001010_1010101100000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01110001010_1010100110000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b1_01110001010_1010100000000000000000000000000000000000000000000000),
Some(bf16(0x8554)),
);
roundtrip(
f64::from_bits(0b1_01110001010_1010101000000000000000000000000000000000000000000000),
Some(bf16(0x8555)),
);
// exp out of range but finite
roundtrip(
f64::from_bits(0b1_11110001010_1010101000000000000000000000000000000000000000000000),
None,
);
// explicitly check denormals
roundtrip(
f64::from_bits(0b0_01101111010_0000000000000000000000000000000000000000000000000000),
Some(bf16(1)),
);
roundtrip(
f64::from_bits(0b0_01101111011_1000000000000000000000000000000000000000000000000000),
Some(bf16(3)),
);
roundtrip(
f64::from_bits(0b0_01101111011_1100000000000000000000000000000000000000000000000000),
None,
);
// Due to being denormal, this is truncated out
roundtrip(
f64::from_bits(0b0_01101111010_0001000000000000000000000000000000000000000000000000),
None,
);
roundtrip(
f64::from_bits(0b0_01101111010_1000000000000000000000000000000000000000000000000000),
None,
);
}
}
Loading
Loading