From 07196108331dd72759a846b502202523683c5ebd Mon Sep 17 00:00:00 2001 From: Alex Huszagh Date: Sun, 15 Dec 2024 11:43:46 -0600 Subject: [PATCH] Add lossless `TryFrom` implementations. This adds `from_f*_lossless` functions, which then are implemented as `TryFrom`, which will only return values if they can fully roundtrip. Related to https://github.com/starkat99/half-rs/issues/90 --- CHANGELOG.md | 13 +- Cargo.lock | 2 +- Cargo.toml | 2 +- devel/Cargo.lock | 2 +- devel/Cargo.toml | 3 + devel/tests/bfloat_tests.rs | 21 +++ src/bfloat.rs | 248 ++++++++++++++++++++++++++++++++++++ src/binary16.rs | 213 +++++++++++++++++++++++++++++++ src/error.rs | 14 ++ src/lib.rs | 3 + src/try_from.rs | 134 +++++++++++++++++++ 11 files changed, 649 insertions(+), 6 deletions(-) create mode 100644 src/error.rs create mode 100644 src/try_from.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 097e111..4a24e6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] -## [0.1.1][v0.1.1] - 2024-12-12 +## [0.1.2][v0.1.2] - 2024-12-12 + +## Added + +- Added lossless `from_f*_lossless` functions and `TryFrom` implementations which will never have rounding error. + +## [0.1.1][v0.1.1] - 2024-12-12 ## Changed @@ -471,8 +477,9 @@ These were all changes for half, which `float16` is a fork of. [Unreleased]: https://github.com/starkat99/half-rs/compare/v2.4.1...HEAD -[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/0.1.1...0.1.0 -[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/0.1.0...v2.4.0 +[v0.1.2]: https://github.com/Alexhuszagh/float16/compare/v0.1.1...v0.1.2 +[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v0.1.1 +[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v2.4.0 [2.4.0]: https://github.com/starkat99/half-rs/compare/v2.3.1...v2.4.0 [2.3.1]: https://github.com/starkat99/half-rs/compare/v2.3.0...v2.3.1 [2.3.0]: https://github.com/starkat99/half-rs/compare/v2.2.1...v2.3.0 diff --git a/Cargo.lock b/Cargo.lock index 0af5321..8c0532c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,7 +10,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "float16" -version = "0.1.1" +version = "0.1.2" dependencies = [ "cfg-if", "rustc_version", diff --git a/Cargo.toml b/Cargo.toml index d96f01c..dd4a93a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "float16" # Remember to keep in sync with html_root_url crate attribute -version = "0.1.1" +version = "0.1.2" authors = ["Kathryn Long ", "Alex Huszagh "] description = "Half-precision floating point f16 and bf16 types for Rust implementing the IEEE 754-2008 standard binary16 and bfloat16 types." repository = "https://github.com/Alexhuszagh/float16" diff --git a/devel/Cargo.lock b/devel/Cargo.lock index e5747a2..a124147 100644 --- a/devel/Cargo.lock +++ b/devel/Cargo.lock @@ -184,7 +184,7 @@ dependencies = [ [[package]] name = "float16" -version = "0.1.0" +version = "0.1.1" dependencies = [ "cfg-if", "rustc_version", diff --git a/devel/Cargo.toml b/devel/Cargo.toml index 966313f..bce52bb 100644 --- a/devel/Cargo.toml +++ b/devel/Cargo.toml @@ -5,6 +5,9 @@ authors = ["Alex Huszagh "] edition = "2021" publish = false +[features] +std = ["float16/std"] + [dependencies.float16] path = ".." default-features = false diff --git a/devel/tests/bfloat_tests.rs b/devel/tests/bfloat_tests.rs index 1e55066..50512c9 100644 --- a/devel/tests/bfloat_tests.rs +++ b/devel/tests/bfloat_tests.rs @@ -37,3 +37,24 @@ fn qc_roundtrip_bf16_f64_is_identity(bits: u16) -> bool { f.to_bits() == roundtrip.to_bits() } } + +#[quickcheck] +fn qc_roundtrip_from_f32_lossless(f: f32) -> bool { + if let Some(b) = bf16::from_f32_lossless(f) { + (b.is_nan() == f.is_nan()) || b.as_f32() == f + } else { + true + } +} + +#[quickcheck] +fn qc_roundtrip_from_f64_lossless(f: f64) -> bool { + if let Some(b) = bf16::from_f64_lossless(f) { + if !((b.is_nan() && f.is_nan()) || b.as_f64() == f) { + println!("b {}, f {}", b.to_bits(), f.to_bits()); + } + (b.is_nan() == f.is_nan()) || b.as_f64() == f + } else { + true + } +} diff --git a/src/bfloat.rs b/src/bfloat.rs index ed332b6..c8da24c 100644 --- a/src/bfloat.rs +++ b/src/bfloat.rs @@ -22,6 +22,9 @@ use core::{ str::FromStr, }; +use crate::error::TryFromFloatError; +use crate::try_from::try_from_lossless; + pub(crate) mod convert; /// A 16-bit floating point type implementing the [`bfloat16`] format. @@ -76,6 +79,61 @@ impl bf16 { bf16(convert::f32_to_bf16(value)) } + /// Create a [`struct@bf16`] loslessly from an [`f32`]. + /// + /// This is only true if the [`f32`] is non-finite + /// (infinite or NaN), or no non-zero bits would + /// be truncated. + /// + /// "Lossless" does not mean the data is represented the + /// same as a decimal number. For example, an [`f32`] + /// and [`f64`] have the significant digits (excluding the + /// hidden bit) for a value closest to `1e35` of: + /// - `f32`: `110100001001100001100` + /// - `f64`: `11010000100110000110000000000000000000000000000000` + /// + /// However, the [`f64`] is displayed as `1.0000000409184788e+35`, + /// while the value closest to `1e35` in [`f64`] is + /// `11010000100110000101110010110001110100110110000010`. This + /// makes it look like precision has been lost but this is + /// due to the approximations used to represent binary values as + /// a decimal. + /// + /// This does not respect signalling NaNs: if the value + /// is NaN or inf, then it will return that value. + /// + /// Since [`struct@bf16`] has the same number of exponent + /// bits as [`f32`], this is effectively just checking if the + /// value is non-finite (infinite or NaN) or the value + /// is normal and the lower 16 bits are 0. + #[inline] + pub const fn from_f32_lossless(value: f32) -> Option { + // NOTE: This logic is effectively just getting the top 16 bits + // and the bottom 16 bits, but it's done explicitly with mantissa + // digits for this reason. For explicit clarity, we remove the + // hidden bit in our exponent logic + const BF16_MANT_BITS: u32 = bf16::MANTISSA_DIGITS - 1; + const F32_MANT_BITS: u32 = f32::MANTISSA_DIGITS - 1; + const EXP_MASK: u32 = (f32::MAX_EXP as u32 * 2 - 1) << F32_MANT_BITS; + const TRUNCATED: u32 = F32_MANT_BITS - BF16_MANT_BITS; + const TRUNC_MASK: u32 = (1 << TRUNCATED) - 1; + + // SAFETY: safe since it's plain old data + let bits: u32 = unsafe { core::mem::transmute(value) }; + + // `bits & exp_mask == exp_mask` -> infinite or NaN + // `truncated == 0` -> no bits truncated + // since the exp ranges are the same, any denormal handling + // is already implicit. + let exp = bits & EXP_MASK; + let is_special = exp == EXP_MASK; + if is_special || bits & TRUNC_MASK == 0 { + Some(Self::from_f32_const(value)) + } else { + None + } + } + /// Constructs a [`struct@bf16`] value from a 64-bit floating point value. /// /// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will @@ -107,6 +165,41 @@ impl bf16 { bf16(convert::f64_to_bf16(value)) } + /// Create a [`struct@bf16`] loslessly from an [`f64`]. + /// + /// This is only true if the [`f64`] is non-finite + /// (infinite or NaN), zero, or the exponent can be + /// represented by a normal [`struct@bf16`] and no + /// non-zero bits would be truncated. + /// + /// "Lossless" does not mean the data is represented the + /// same as a decimal number. For example, an [`f32`] + /// and [`f64`] have the significant digits (excluding the + /// hidden bit) for a value closest to `1e35` of: + /// - `f32`: `110100001001100001100` + /// - `f64`: `11010000100110000110000000000000000000000000000000` + /// + /// However, the [`f64`] is displayed as `1.0000000409184788e+35`, + /// while the value closest to `1e35` in [`f64`] is + /// `11010000100110000101110010110001110100110110000010`. This + /// makes it look like precision has been lost but this is + /// due to the approximations used to represent binary values as + /// a decimal. + /// + /// This does not respect signalling NaNs: if the value + /// is NaN or inf, then it will return that value. + #[inline] + pub const fn from_f64_lossless(value: f64) -> Option { + try_from_lossless!( + value => value, + half => bf16, + full => f64, + half_bits => u16, + full_bits => u64, + to_half => from_f64 + ) + } + /// Converts a [`struct@bf16`] into the underlying bit representation. #[inline] #[must_use] @@ -883,6 +976,24 @@ impl From for bf16 { } } +impl TryFrom for bf16 { + type Error = TryFromFloatError; + + #[inline] + fn try_from(x: f32) -> Result { + Self::from_f32_lossless(x).ok_or(TryFromFloatError(())) + } +} + +impl TryFrom for bf16 { + type Error = TryFromFloatError; + + #[inline] + fn try_from(x: f64) -> Result { + Self::from_f64_lossless(x).ok_or(TryFromFloatError(())) + } +} + impl PartialEq for bf16 { fn eq(&self, other: &bf16) -> bool { eq(*self, *other) @@ -1730,4 +1841,141 @@ mod test { assert_eq!(bf16::from_f64(252.50f64).to_bits(), bf16::from_f64(252.0).to_bits()); assert_eq!(bf16::from_f64(252.51f64).to_bits(), bf16::from_f64(253.0).to_bits()); } + + #[test] + fn from_f32_lossless() { + let from_f32 = |v: f32| bf16::from_f32_lossless(v); + let roundtrip = |v: f32, expected: Option| { + let half = from_f32(v); + assert_eq!(half, expected); + if !expected.is_none() { + let as_f32 = expected.unwrap().to_f32_const(); + assert_eq!(v, as_f32); + } + }; + + assert_eq!(from_f32(f32::NAN).map(bf16::is_nan), Some(true)); + roundtrip(f32::INFINITY, Some(bf16::INFINITY)); + roundtrip(f32::NEG_INFINITY, Some(bf16::NEG_INFINITY)); + roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(bf16(0))); + roundtrip( + f32::from_bits(0b1_00000000_00000000000000000000000), + Some(bf16(bf16::SIGN_MASK)), + ); + roundtrip(f32::from_bits(1), None); + roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None); + roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None); + roundtrip(f32::from_bits(0b0_00001010_10101011000000000000000), None); + roundtrip( + f32::from_bits(0b0_00001010_10101010000000000000000), + Some(bf16(0b0_00001010_1010101)), + ); + roundtrip(f32::from_bits(0b0_00000000_10000000000000000000000), Some(bf16(0x40))); + // special truncation with denormals, etc. + roundtrip(f32::from_bits(0b0_00000000_00000001000000000000000), None); + roundtrip(f32::from_bits(0b0_00000000_00000010000000000000000), Some(bf16(1))); + roundtrip(f32::from_bits(0b0_00000000_00000100000000000000000), Some(bf16(2))); + roundtrip(f32::from_bits(0b0_00000000_00000110000000000000000), Some(bf16(3))); + roundtrip(f32::from_bits(0b0_00000000_00000111000000000000000), None); + roundtrip(f32::from_bits(0b0_00001011_10100111101101101001001), None); + // 1.99170198e-35 and has bits until 16 to the end, so truncated 2 + roundtrip(f32::from_bits(0b0_00001011_10100111100000000000000), None); + // 1.99170198e-35 and has bits until 15 to the end, so truncated 1 + roundtrip(f32::from_bits(0b0_00001011_10100111000000000000000), None); + // 1.99170198e-35 and has bits until 15 to the end, so truncated 1 + roundtrip(f32::from_bits(0b0_00001011_10100110000000000000000), Some(bf16(0x05d3))); + } + + #[test] + fn from_f64_lossless() { + let from_f64 = |v: f64| bf16::from_f64_lossless(v); + let roundtrip = |v: f64, expected: Option| { + let half = from_f64(v); + assert_eq!(half, expected); + if !expected.is_none() { + let as_f64 = expected.unwrap().to_f64_const(); + assert_eq!(v, as_f64); + } + }; + + assert_eq!(from_f64(f64::NAN).map(bf16::is_nan), Some(true)); + roundtrip(f64::INFINITY, Some(bf16::INFINITY)); + roundtrip(f64::NEG_INFINITY, Some(bf16::NEG_INFINITY)); + roundtrip( + f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000), + Some(bf16(0)), + ); + roundtrip( + f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000), + Some(bf16(bf16::SIGN_MASK)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111), + None, + ); + // 1.99170198e-35 and has bits until 44 to the end, so truncated 1 + roundtrip( + f64::from_bits(0b0_01110001010_1010100100000000000000000000000000000000000000000000), + None, + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010100000000000000000000000000000000000000000000000), + Some(bf16(0x0554)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010101000000000000000000000000000000000000000000000), + Some(bf16(0x0555)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010110000000000000000000000000000000000000000000000), + Some(bf16(0x0556)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010111000000000000000000000000000000000000000000000), + Some(bf16(0x0557)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010101100000000000000000000000000000000000000000000), + None, + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010100110000000000000000000000000000000000000000000), + None, + ); + roundtrip( + f64::from_bits(0b1_01110001010_1010100000000000000000000000000000000000000000000000), + Some(bf16(0x8554)), + ); + roundtrip( + f64::from_bits(0b1_01110001010_1010101000000000000000000000000000000000000000000000), + Some(bf16(0x8555)), + ); + // exp out of range but finite + roundtrip( + f64::from_bits(0b1_11110001010_1010101000000000000000000000000000000000000000000000), + None, + ); + // explicitly check denormals + roundtrip( + f64::from_bits(0b0_01101111010_0000000000000000000000000000000000000000000000000000), + Some(bf16(1)), + ); + roundtrip( + f64::from_bits(0b0_01101111011_1000000000000000000000000000000000000000000000000000), + Some(bf16(3)), + ); + roundtrip( + f64::from_bits(0b0_01101111011_1100000000000000000000000000000000000000000000000000), + None, + ); + // Due to being denormal, this is truncated out + roundtrip( + f64::from_bits(0b0_01101111010_0001000000000000000000000000000000000000000000000000), + None, + ); + roundtrip( + f64::from_bits(0b0_01101111010_1000000000000000000000000000000000000000000000000000), + None, + ); + } } diff --git a/src/binary16.rs b/src/binary16.rs index 41affe5..ab953d1 100644 --- a/src/binary16.rs +++ b/src/binary16.rs @@ -22,6 +22,9 @@ use core::{ str::FromStr, }; +use crate::error::TryFromFloatError; +use crate::try_from::try_from_lossless; + pub(crate) mod arch; /// A 16-bit floating point type implementing the IEEE 754-2008 standard @@ -98,6 +101,41 @@ impl f16 { f16(arch::f32_to_f16(value)) } + /// Create a [`struct@f16`] loslessly from an [`f32`]. + /// + /// This is only true if the [`f32`] is non-finite + /// (infinite or NaN), or the exponent can be represented + /// by a normal [`struct@f16`] and no non-zero bits would + /// be truncated. + /// + /// "Lossless" does not mean the data is represented the + /// same as a decimal number. For example, an [`f32`] + /// and [`f64`] have the significant digits (excluding the + /// hidden bit) for a value closest to `1e35` of: + /// - `f32`: `110100001001100001100` + /// - `f64`: `11010000100110000110000000000000000000000000000000` + /// + /// However, the [`f64`] is displayed as `1.0000000409184788e+35`, + /// while the value closest to `1e35` in [`f64`] is + /// `11010000100110000101110010110001110100110110000010`. This + /// makes it look like precision has been lost but this is + /// due to the approximations used to represent binary values as + /// a decimal. + /// + /// This does not respect signalling NaNs: if the value + /// is NaN or inf, then it will return that value. + #[inline] + pub const fn from_f32_lossless(value: f32) -> Option { + try_from_lossless!( + value => value, + half => f16, + full => f32, + half_bits => u16, + full_bits => u32, + to_half => from_f32 + ) + } + /// Constructs a 16-bit floating point value from a 64-bit floating point /// value. /// @@ -160,6 +198,41 @@ impl f16 { f16(arch::f64_to_f16(value)) } + /// Create a [`struct@f16`] loslessly from an [`f64`]. + /// + /// This is only true if the [`f64`] is non-finite + /// (infinite or NaN), or the exponent can be represented + /// by a normal [`struct@f16`] and no non-zero bits would + /// be truncated. + /// + /// "Lossless" does not mean the data is represented the + /// same as a decimal number. For example, an [`f32`] + /// and [`f64`] have the significant digits (excluding the + /// hidden bit) for a value closest to `1e35` of: + /// - `f32`: `110100001001100001100` + /// - `f64`: `11010000100110000110000000000000000000000000000000` + /// + /// However, the [`f64`] is displayed as `1.0000000409184788e+35`, + /// while the value closest to `1e35` in [`f64`] is + /// `11010000100110000101110010110001110100110110000010`. This + /// makes it look like precision has been lost but this is + /// due to the approximations used to represent binary values as + /// a decimal. + /// + /// This does not respect signalling NaNs: if the value + /// is NaN or inf, then it will return that value. + #[inline] + pub const fn from_f64_lossless(value: f64) -> Option { + try_from_lossless!( + value => value, + half => f16, + full => f64, + half_bits => u16, + full_bits => u64, + to_half => from_f64 + ) + } + /// Converts a [`struct@f16`] into the underlying bit representation. #[inline] #[must_use] @@ -963,6 +1036,24 @@ impl From for f16 { } } +impl TryFrom for f16 { + type Error = TryFromFloatError; + + #[inline] + fn try_from(x: f32) -> Result { + Self::from_f32_lossless(x).ok_or(TryFromFloatError(())) + } +} + +impl TryFrom for f16 { + type Error = TryFromFloatError; + + #[inline] + fn try_from(x: f64) -> Result { + Self::from_f64_lossless(x).ok_or(TryFromFloatError(())) + } +} + impl PartialEq for f16 { #[inline] fn eq(&self, other: &f16) -> bool { @@ -1908,4 +1999,126 @@ mod test { assert_eq!(const_bits, bits); assert!(inst_bits.abs_diff(bits) <= max_diff); } + + #[test] + fn from_f32_lossless() { + let from_f32 = |v: f32| f16::from_f32_lossless(v); + let roundtrip = |v: f32, expected: Option| { + let half = from_f32(v); + assert_eq!(half, expected); + if !expected.is_none() { + let as_f32 = expected.unwrap().to_f32_const(); + assert_eq!(v, as_f32); + } + }; + + assert_eq!(from_f32(f32::NAN).map(f16::is_nan), Some(true)); + roundtrip(f32::INFINITY, Some(f16::INFINITY)); + roundtrip(f32::NEG_INFINITY, Some(f16::NEG_INFINITY)); + roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(f16(0))); + roundtrip(f32::from_bits(0b1_00000000_00000000000000000000000), Some(f16(f16::SIGN_MASK))); + roundtrip(f32::from_bits(1), None); + + // special truncation with denormals, etc. + roundtrip(f32::from_bits(0b0_01100111_00000000000000000000000), Some(f16(1))); + roundtrip(f32::from_bits(0b0_01101000_00000000000000000000000), Some(f16(2))); + roundtrip(f32::from_bits(0b0_01101000_10000000000000000000000), Some(f16(3))); + roundtrip(f32::from_bits(0b0_01100111_10000000000000000000000), None); + roundtrip(f32::from_bits(0b0_01101000_11000000000000000000000), None); + // ~2.2888184e-5 and has bits until 16 to the end, so truncated 2. but this is + // denormal as f16 + roundtrip(f32::from_bits(0b0_01101111_00000000000000000000000), Some(f16(0x100))); + roundtrip(f32::from_bits(0b0_01101111_10000000000000000000000), Some(f16(0x180))); + roundtrip(f32::from_bits(0b0_01101111_11000000000000000000000), Some(f16(0x1c0))); + roundtrip(f32::from_bits(0b0_01101111_11000001000000000000000), Some(f16(0x1c1))); + roundtrip(f32::from_bits(0b0_01101111_11000001100000000000000), None); + //2.0f32 + roundtrip(f32::from_bits(0b0_10000000_00000000000000000000000), Some(f16(0x4000))); + roundtrip(f32::from_bits(0b0_10000000_10000000000000000000000), Some(f16(0x4200))); + roundtrip(f32::from_bits(0b0_10000000_10000000010000000000000), Some(f16(0x4201))); + roundtrip(f32::from_bits(0b0_10000000_10000000011000000000000), None); + // check overflow + roundtrip(f32::from_bits(0b0_10001111_00000000000000000000000), None); + roundtrip(f32::from_bits(0b0_10001110_00000000000000000000000), Some(f16(0x7800))); + } + + #[test] + fn from_f64_lossless() { + let from_f64 = |v: f64| f16::from_f64_lossless(v); + let roundtrip = |v: f64, expected: Option| { + let half = from_f64(v); + assert_eq!(half, expected); + if !expected.is_none() { + let as_f64 = expected.unwrap().to_f64_const(); + assert_eq!(v, as_f64); + } + }; + + assert_eq!(from_f64(f64::NAN).map(f16::is_nan), Some(true)); + roundtrip(f64::INFINITY, Some(f16::INFINITY)); + roundtrip(f64::NEG_INFINITY, Some(f16::NEG_INFINITY)); + roundtrip( + f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000), + Some(f16(0)), + ); + roundtrip( + f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000), + Some(f16(f16::SIGN_MASK)), + ); + roundtrip( + f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111), + None, + ); + // check overflow to inf + roundtrip( + f64::from_bits(0b0_10000001110_1000000000000000000000000000000000000000000000000000), + Some(f16(0x7a00)), + ); + roundtrip( + f64::from_bits(0b0_10000001111_1000000000000000000000000000000000000000000000000000), + None, + ); + // check denormals and truncation + roundtrip( + f64::from_bits(0b0_01111100111_0000000000000000000000000000000000000000000000000000), + Some(f16(1)), + ); + roundtrip( + f64::from_bits(0b0_01111100111_1000000000000000000000000000000000000000000000000000), + None, + ); + roundtrip( + f64::from_bits(0b0_01111101000_0000000000000000000000000000000000000000000000000000), + Some(f16(2)), + ); + roundtrip( + f64::from_bits(0b0_01111101000_1000000000000000000000000000000000000000000000000000), + Some(f16(3)), + ); + roundtrip( + f64::from_bits(0b0_01111101000_1100000000000000000000000000000000000000000000000000), + None, + ); + // check basic, normal and positive numbers + roundtrip( + f64::from_bits(0b0_01111111000_0000000000000000000000000000000000000000000000000000), + Some(f16(0x2000)), + ); + roundtrip( + f64::from_bits(0b0_01111111000_1000000000000000000000000000000000000000000000000000), + Some(f16(0x2200)), + ); + roundtrip( + f64::from_bits(0b0_01111111000_1110000000000000000000000000000000000000000000000000), + Some(f16(0x2380)), + ); + roundtrip( + f64::from_bits(0b0_01111111000_1110000001000000000000000000000000000000000000000000), + Some(f16(0x2381)), + ); + roundtrip( + f64::from_bits(0b0_01111111000_1110000001100000000000000000000000000000000000000000), + None, + ); + } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..b4c0d5e --- /dev/null +++ b/src/error.rs @@ -0,0 +1,14 @@ +//! Error type for numeric conversion functions. + +use core::fmt; + +/// The error type returned when a checked integral type conversion fails. +pub struct TryFromFloatError(pub(crate) ()); + +impl fmt::Display for TryFromFloatError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let msg = "out of range integral type conversion attempted"; + fmt::Display::fmt(msg, f) + } +} diff --git a/src/lib.rs b/src/lib.rs index 415725a..f02942a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,11 +70,14 @@ mod bfloat; mod binary16; +mod error; mod leading_zeros; mod slice; +mod try_from; pub use bfloat::bf16; pub use binary16::f16; +pub use error::TryFromFloatError; #[cfg(not(target_arch = "spirv"))] pub use crate::slice::{HalfBitsSliceExt, HalfFloatSliceExt}; diff --git a/src/try_from.rs b/src/try_from.rs new file mode 100644 index 0000000..1caa177 --- /dev/null +++ b/src/try_from.rs @@ -0,0 +1,134 @@ +// Try to convert a value from a wider type, only converting +// the value if it can be losslessly converted, otherwise returning none. +macro_rules! try_from_lossless { + ( + value => + $value:ident,half => + $half:ident,full => + $full:ident,half_bits => + $half_bits:ident,full_bits => + $full_bits:ident,to_half => + $to_half:ident + ) => {{ + // let's use `f16` and `f64` as an example: + // the `f64` is broken down into the following components: + // - sign: 1 + // - exp: 111111111110000000000000000000000000000000000000000000000000000 + // - mant: 1111111111111111111111111111111111111111111111111111 + // + // value is stored as `2^n` where the lowest bit of the exp is the implicit + // hidden bit, that is, the `1`, while the top bit of `mantissa` is `1/2`, + // then `1/4`, etc. for an `f16`, we then have: + // - sign: 1 + // - exp: 11111 + // - mant: 1111111111 + // or the bottom 42 bits are truncated during the conversion, if the exponents + // are in range. we only need to consider special cases, that is, subnormal + // floats, where all exponent bits are 0, if both types have the same number of + // exponent bits (`f32` to `bf16`). so only if the bottom `52 - 10` bits are + // `0`, then it has a lossless conversion + // + // we do have special cases for non-finite values, NaN and +/- infinity. since + // a NaN is still a NaN no matter the lower `51` bits in the longer type, if + // we can ignore the result no matter the lower bits. + let bits: $full_bits = unsafe { core::mem::transmute($value) }; + + // get our masks and extract the IEEE754 components. + const FULL_MANTISSA_BITS: u32 = <$full>::MANTISSA_DIGITS - 1; + const FULL_SIGN_MASK: $full_bits = 1 << (<$full_bits>::BITS - 1); + const FULL_EXPONENT_MASK: $full_bits = + (<$full>::MAX_EXP as $full_bits * 2 - 1) << FULL_MANTISSA_BITS; + const FULL_MANTISSA_MASK: $full_bits = (1 << FULL_MANTISSA_BITS) - 1; + let full_sign = bits & FULL_SIGN_MASK; + let full_exp = bits & FULL_EXPONENT_MASK; + let full_mant = bits & FULL_MANTISSA_MASK; + + const HALF_MANTISSA_BITS: u32 = <$half>::MANTISSA_DIGITS - 1; + const HALF_EXPONENT_MASK: $half_bits = + (<$half>::MAX_EXP as $half_bits * 2 - 1) << HALF_MANTISSA_BITS; + let sign_shift = <$full_bits>::BITS - <$half_bits>::BITS; + let half_sign = (full_sign >> sign_shift) as $half_bits; + + // we use the number of bits without the hidden bit. + // we want to know the number of bits truncated and a mask for + // all bits that could be truncated. + const TRUNCATED_BITS: u32 = FULL_MANTISSA_BITS - HALF_MANTISSA_BITS; + + // check for if we have a special (non-finite) number + if full_exp == FULL_EXPONENT_MASK { + let half_exp = HALF_EXPONENT_MASK; + let half_mant = (full_mant >> TRUNCATED_BITS) as $half_bits; + return Some($half(half_sign | half_exp | half_mant)); + } + + // check for zero, which would otherwise underflow + if (bits & !FULL_SIGN_MASK) == 0 { + return Some($half(half_sign)); + } + + // need to get our unbiased exponent. exponents are stored with + // the value as (2^exp - (2^(expbits-1) - 1)`. the max, unbiased + // exp for `bf16` is `127` and the min non-denormal one is `-126`. + // we need the hidden bit in this biased exp. + const FULL_BIAS: i32 = <$full>::MAX_EXP - 1; + let full_biased = (full_exp >> FULL_MANTISSA_BITS) as i32; + let full_unbiased = full_biased - FULL_BIAS; + + // now we need to know if our exponent is in the range. our range is from + // if your small exp is valid for our float, that is, unbiased it's in + // the range `2 - 2^(expbits-1)` or `1 - bias` for a normal float + // (biased exp `>= 1`), but a denormal float works so we want + // `1 - bias`. Our max exp finite is `2^(expbits-1) - 1` or `bias`. + // all special values always are valid, so we also accept when all + // the exponent bits are set. we have a special case: when the two + // exponents are the same number of bits: then it's **ALWAYS** valid. + // + // but this still needs to consider denormal values, or where we have + // no exp bits + const HALF_BIAS: i32 = <$half>::MAX_EXP - 1; + const HALF_MIN_EXP: i32 = 1 - HALF_BIAS; + const FULL_EXP_BITS: u32 = <$full_bits>::BITS - FULL_MANTISSA_BITS - 1; + const HALF_EXP_BITS: u32 = <$half_bits>::BITS - HALF_MANTISSA_BITS - 1; + const HALF_MAX_EXP: i32 = HALF_BIAS; + const HALF_MIN_DENORMAL_EXP: i32 = HALF_MIN_EXP - HALF_MANTISSA_BITS as i32; + let exp_in_range = if FULL_EXP_BITS == HALF_EXP_BITS { + true + } else { + full_unbiased >= HALF_MIN_DENORMAL_EXP && full_unbiased <= HALF_MAX_EXP + }; + if !exp_in_range { + return None; + } + + // get if we have any truncated bits, otherwise, we have an exact result + let half_biased = full_unbiased + HALF_BIAS; + let is_denormal = half_biased <= 0; + let truncated_bits = if is_denormal { + // NOTE: This needs an extra bit for what was formerly the hidden bit + (TRUNCATED_BITS as i32 - half_biased + 1) as u32 + } else { + TRUNCATED_BITS + }; + let truncated_mask: $full_bits = (1 << truncated_bits) - 1; + if bits & truncated_mask != 0 { + return None; + } + + // now we need to reassemble our float components. remember if we have + // a denormal float in the result we need to move our implicit hidden + // bit out. + let full_hidden_bit: $full_bits = 1 << FULL_MANTISSA_BITS; + let (half_mant, half_exp) = if is_denormal { + let half_mant = ((full_mant | full_hidden_bit) >> truncated_bits) as $half_bits; + (half_mant, 0) + } else { + let half_mant = (full_mant >> truncated_bits) as $half_bits; + let half_exp = (half_biased as $half_bits) << HALF_MANTISSA_BITS; + (half_mant, half_exp) + }; + + Some($half(half_sign | half_exp | half_mant)) + }}; +} + +pub(crate) use try_from_lossless;