diff --git a/CHANGELOG.md b/CHANGELOG.md
index 097e111..4a24e6b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
-## [0.1.1][v0.1.1] - 2024-12-12
+## [0.1.2][v0.1.2] - 2024-12-12
+
+## Added
+
+- Added lossless `from_f*_lossless` functions and `TryFrom` implementations which will never have rounding error.
+
+## [0.1.1][v0.1.1] - 2024-12-12
## Changed
@@ -471,8 +477,9 @@ These were all changes for half, which `float16` is a fork of.
[Unreleased]: https://github.com/starkat99/half-rs/compare/v2.4.1...HEAD
-[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/0.1.1...0.1.0
-[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/0.1.0...v2.4.0
+[v0.1.2]: https://github.com/Alexhuszagh/float16/compare/v0.1.1...v0.1.2
+[v0.1.1]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v0.1.1
+[v0.1.0]: https://github.com/Alexhuszagh/float16/compare/v0.1.0...v2.4.0
[2.4.0]: https://github.com/starkat99/half-rs/compare/v2.3.1...v2.4.0
[2.3.1]: https://github.com/starkat99/half-rs/compare/v2.3.0...v2.3.1
[2.3.0]: https://github.com/starkat99/half-rs/compare/v2.2.1...v2.3.0
diff --git a/Cargo.lock b/Cargo.lock
index 0af5321..8c0532c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10,7 +10,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "float16"
-version = "0.1.1"
+version = "0.1.2"
dependencies = [
"cfg-if",
"rustc_version",
diff --git a/Cargo.toml b/Cargo.toml
index d96f01c..dd4a93a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "float16"
# Remember to keep in sync with html_root_url crate attribute
-version = "0.1.1"
+version = "0.1.2"
authors = ["Kathryn Long ", "Alex Huszagh "]
description = "Half-precision floating point f16 and bf16 types for Rust implementing the IEEE 754-2008 standard binary16 and bfloat16 types."
repository = "https://github.com/Alexhuszagh/float16"
diff --git a/devel/Cargo.lock b/devel/Cargo.lock
index e5747a2..a124147 100644
--- a/devel/Cargo.lock
+++ b/devel/Cargo.lock
@@ -184,7 +184,7 @@ dependencies = [
[[package]]
name = "float16"
-version = "0.1.0"
+version = "0.1.1"
dependencies = [
"cfg-if",
"rustc_version",
diff --git a/devel/Cargo.toml b/devel/Cargo.toml
index 966313f..bce52bb 100644
--- a/devel/Cargo.toml
+++ b/devel/Cargo.toml
@@ -5,6 +5,9 @@ authors = ["Alex Huszagh "]
edition = "2021"
publish = false
+[features]
+std = ["float16/std"]
+
[dependencies.float16]
path = ".."
default-features = false
diff --git a/devel/tests/bfloat_tests.rs b/devel/tests/bfloat_tests.rs
index 1e55066..50512c9 100644
--- a/devel/tests/bfloat_tests.rs
+++ b/devel/tests/bfloat_tests.rs
@@ -37,3 +37,24 @@ fn qc_roundtrip_bf16_f64_is_identity(bits: u16) -> bool {
f.to_bits() == roundtrip.to_bits()
}
}
+
+#[quickcheck]
+fn qc_roundtrip_from_f32_lossless(f: f32) -> bool {
+ if let Some(b) = bf16::from_f32_lossless(f) {
+ (b.is_nan() == f.is_nan()) || b.as_f32() == f
+ } else {
+ true
+ }
+}
+
+#[quickcheck]
+fn qc_roundtrip_from_f64_lossless(f: f64) -> bool {
+ if let Some(b) = bf16::from_f64_lossless(f) {
+ if !((b.is_nan() && f.is_nan()) || b.as_f64() == f) {
+ println!("b {}, f {}", b.to_bits(), f.to_bits());
+ }
+ (b.is_nan() == f.is_nan()) || b.as_f64() == f
+ } else {
+ true
+ }
+}
diff --git a/src/bfloat.rs b/src/bfloat.rs
index ed332b6..c8da24c 100644
--- a/src/bfloat.rs
+++ b/src/bfloat.rs
@@ -22,6 +22,9 @@ use core::{
str::FromStr,
};
+use crate::error::TryFromFloatError;
+use crate::try_from::try_from_lossless;
+
pub(crate) mod convert;
/// A 16-bit floating point type implementing the [`bfloat16`] format.
@@ -76,6 +79,61 @@ impl bf16 {
bf16(convert::f32_to_bf16(value))
}
+ /// Create a [`struct@bf16`] loslessly from an [`f32`].
+ ///
+ /// This is only true if the [`f32`] is non-finite
+ /// (infinite or NaN), or no non-zero bits would
+ /// be truncated.
+ ///
+ /// "Lossless" does not mean the data is represented the
+ /// same as a decimal number. For example, an [`f32`]
+ /// and [`f64`] have the significant digits (excluding the
+ /// hidden bit) for a value closest to `1e35` of:
+ /// - `f32`: `110100001001100001100`
+ /// - `f64`: `11010000100110000110000000000000000000000000000000`
+ ///
+ /// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
+ /// while the value closest to `1e35` in [`f64`] is
+ /// `11010000100110000101110010110001110100110110000010`. This
+ /// makes it look like precision has been lost but this is
+ /// due to the approximations used to represent binary values as
+ /// a decimal.
+ ///
+ /// This does not respect signalling NaNs: if the value
+ /// is NaN or inf, then it will return that value.
+ ///
+ /// Since [`struct@bf16`] has the same number of exponent
+ /// bits as [`f32`], this is effectively just checking if the
+ /// value is non-finite (infinite or NaN) or the value
+ /// is normal and the lower 16 bits are 0.
+ #[inline]
+ pub const fn from_f32_lossless(value: f32) -> Option {
+ // NOTE: This logic is effectively just getting the top 16 bits
+ // and the bottom 16 bits, but it's done explicitly with mantissa
+ // digits for this reason. For explicit clarity, we remove the
+ // hidden bit in our exponent logic
+ const BF16_MANT_BITS: u32 = bf16::MANTISSA_DIGITS - 1;
+ const F32_MANT_BITS: u32 = f32::MANTISSA_DIGITS - 1;
+ const EXP_MASK: u32 = (f32::MAX_EXP as u32 * 2 - 1) << F32_MANT_BITS;
+ const TRUNCATED: u32 = F32_MANT_BITS - BF16_MANT_BITS;
+ const TRUNC_MASK: u32 = (1 << TRUNCATED) - 1;
+
+ // SAFETY: safe since it's plain old data
+ let bits: u32 = unsafe { core::mem::transmute(value) };
+
+ // `bits & exp_mask == exp_mask` -> infinite or NaN
+ // `truncated == 0` -> no bits truncated
+ // since the exp ranges are the same, any denormal handling
+ // is already implicit.
+ let exp = bits & EXP_MASK;
+ let is_special = exp == EXP_MASK;
+ if is_special || bits & TRUNC_MASK == 0 {
+ Some(Self::from_f32_const(value))
+ } else {
+ None
+ }
+ }
+
/// Constructs a [`struct@bf16`] value from a 64-bit floating point value.
///
/// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will
@@ -107,6 +165,41 @@ impl bf16 {
bf16(convert::f64_to_bf16(value))
}
+ /// Create a [`struct@bf16`] loslessly from an [`f64`].
+ ///
+ /// This is only true if the [`f64`] is non-finite
+ /// (infinite or NaN), zero, or the exponent can be
+ /// represented by a normal [`struct@bf16`] and no
+ /// non-zero bits would be truncated.
+ ///
+ /// "Lossless" does not mean the data is represented the
+ /// same as a decimal number. For example, an [`f32`]
+ /// and [`f64`] have the significant digits (excluding the
+ /// hidden bit) for a value closest to `1e35` of:
+ /// - `f32`: `110100001001100001100`
+ /// - `f64`: `11010000100110000110000000000000000000000000000000`
+ ///
+ /// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
+ /// while the value closest to `1e35` in [`f64`] is
+ /// `11010000100110000101110010110001110100110110000010`. This
+ /// makes it look like precision has been lost but this is
+ /// due to the approximations used to represent binary values as
+ /// a decimal.
+ ///
+ /// This does not respect signalling NaNs: if the value
+ /// is NaN or inf, then it will return that value.
+ #[inline]
+ pub const fn from_f64_lossless(value: f64) -> Option {
+ try_from_lossless!(
+ value => value,
+ half => bf16,
+ full => f64,
+ half_bits => u16,
+ full_bits => u64,
+ to_half => from_f64
+ )
+ }
+
/// Converts a [`struct@bf16`] into the underlying bit representation.
#[inline]
#[must_use]
@@ -883,6 +976,24 @@ impl From for bf16 {
}
}
+impl TryFrom for bf16 {
+ type Error = TryFromFloatError;
+
+ #[inline]
+ fn try_from(x: f32) -> Result {
+ Self::from_f32_lossless(x).ok_or(TryFromFloatError(()))
+ }
+}
+
+impl TryFrom for bf16 {
+ type Error = TryFromFloatError;
+
+ #[inline]
+ fn try_from(x: f64) -> Result {
+ Self::from_f64_lossless(x).ok_or(TryFromFloatError(()))
+ }
+}
+
impl PartialEq for bf16 {
fn eq(&self, other: &bf16) -> bool {
eq(*self, *other)
@@ -1730,4 +1841,141 @@ mod test {
assert_eq!(bf16::from_f64(252.50f64).to_bits(), bf16::from_f64(252.0).to_bits());
assert_eq!(bf16::from_f64(252.51f64).to_bits(), bf16::from_f64(253.0).to_bits());
}
+
+ #[test]
+ fn from_f32_lossless() {
+ let from_f32 = |v: f32| bf16::from_f32_lossless(v);
+ let roundtrip = |v: f32, expected: Option| {
+ let half = from_f32(v);
+ assert_eq!(half, expected);
+ if !expected.is_none() {
+ let as_f32 = expected.unwrap().to_f32_const();
+ assert_eq!(v, as_f32);
+ }
+ };
+
+ assert_eq!(from_f32(f32::NAN).map(bf16::is_nan), Some(true));
+ roundtrip(f32::INFINITY, Some(bf16::INFINITY));
+ roundtrip(f32::NEG_INFINITY, Some(bf16::NEG_INFINITY));
+ roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(bf16(0)));
+ roundtrip(
+ f32::from_bits(0b1_00000000_00000000000000000000000),
+ Some(bf16(bf16::SIGN_MASK)),
+ );
+ roundtrip(f32::from_bits(1), None);
+ roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
+ roundtrip(f32::from_bits(0b0_00001010_10101001010110100101110), None);
+ roundtrip(f32::from_bits(0b0_00001010_10101011000000000000000), None);
+ roundtrip(
+ f32::from_bits(0b0_00001010_10101010000000000000000),
+ Some(bf16(0b0_00001010_1010101)),
+ );
+ roundtrip(f32::from_bits(0b0_00000000_10000000000000000000000), Some(bf16(0x40)));
+ // special truncation with denormals, etc.
+ roundtrip(f32::from_bits(0b0_00000000_00000001000000000000000), None);
+ roundtrip(f32::from_bits(0b0_00000000_00000010000000000000000), Some(bf16(1)));
+ roundtrip(f32::from_bits(0b0_00000000_00000100000000000000000), Some(bf16(2)));
+ roundtrip(f32::from_bits(0b0_00000000_00000110000000000000000), Some(bf16(3)));
+ roundtrip(f32::from_bits(0b0_00000000_00000111000000000000000), None);
+ roundtrip(f32::from_bits(0b0_00001011_10100111101101101001001), None);
+ // 1.99170198e-35 and has bits until 16 to the end, so truncated 2
+ roundtrip(f32::from_bits(0b0_00001011_10100111100000000000000), None);
+ // 1.99170198e-35 and has bits until 15 to the end, so truncated 1
+ roundtrip(f32::from_bits(0b0_00001011_10100111000000000000000), None);
+ // 1.99170198e-35 and has bits until 15 to the end, so truncated 1
+ roundtrip(f32::from_bits(0b0_00001011_10100110000000000000000), Some(bf16(0x05d3)));
+ }
+
+ #[test]
+ fn from_f64_lossless() {
+ let from_f64 = |v: f64| bf16::from_f64_lossless(v);
+ let roundtrip = |v: f64, expected: Option| {
+ let half = from_f64(v);
+ assert_eq!(half, expected);
+ if !expected.is_none() {
+ let as_f64 = expected.unwrap().to_f64_const();
+ assert_eq!(v, as_f64);
+ }
+ };
+
+ assert_eq!(from_f64(f64::NAN).map(bf16::is_nan), Some(true));
+ roundtrip(f64::INFINITY, Some(bf16::INFINITY));
+ roundtrip(f64::NEG_INFINITY, Some(bf16::NEG_INFINITY));
+ roundtrip(
+ f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000),
+ Some(bf16(0)),
+ );
+ roundtrip(
+ f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000),
+ Some(bf16(bf16::SIGN_MASK)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111),
+ None,
+ );
+ // 1.99170198e-35 and has bits until 44 to the end, so truncated 1
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010100100000000000000000000000000000000000000000000),
+ None,
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010100000000000000000000000000000000000000000000000),
+ Some(bf16(0x0554)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010101000000000000000000000000000000000000000000000),
+ Some(bf16(0x0555)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010110000000000000000000000000000000000000000000000),
+ Some(bf16(0x0556)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010111000000000000000000000000000000000000000000000),
+ Some(bf16(0x0557)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010101100000000000000000000000000000000000000000000),
+ None,
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010100110000000000000000000000000000000000000000000),
+ None,
+ );
+ roundtrip(
+ f64::from_bits(0b1_01110001010_1010100000000000000000000000000000000000000000000000),
+ Some(bf16(0x8554)),
+ );
+ roundtrip(
+ f64::from_bits(0b1_01110001010_1010101000000000000000000000000000000000000000000000),
+ Some(bf16(0x8555)),
+ );
+ // exp out of range but finite
+ roundtrip(
+ f64::from_bits(0b1_11110001010_1010101000000000000000000000000000000000000000000000),
+ None,
+ );
+ // explicitly check denormals
+ roundtrip(
+ f64::from_bits(0b0_01101111010_0000000000000000000000000000000000000000000000000000),
+ Some(bf16(1)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01101111011_1000000000000000000000000000000000000000000000000000),
+ Some(bf16(3)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01101111011_1100000000000000000000000000000000000000000000000000),
+ None,
+ );
+ // Due to being denormal, this is truncated out
+ roundtrip(
+ f64::from_bits(0b0_01101111010_0001000000000000000000000000000000000000000000000000),
+ None,
+ );
+ roundtrip(
+ f64::from_bits(0b0_01101111010_1000000000000000000000000000000000000000000000000000),
+ None,
+ );
+ }
}
diff --git a/src/binary16.rs b/src/binary16.rs
index 41affe5..ab953d1 100644
--- a/src/binary16.rs
+++ b/src/binary16.rs
@@ -22,6 +22,9 @@ use core::{
str::FromStr,
};
+use crate::error::TryFromFloatError;
+use crate::try_from::try_from_lossless;
+
pub(crate) mod arch;
/// A 16-bit floating point type implementing the IEEE 754-2008 standard
@@ -98,6 +101,41 @@ impl f16 {
f16(arch::f32_to_f16(value))
}
+ /// Create a [`struct@f16`] loslessly from an [`f32`].
+ ///
+ /// This is only true if the [`f32`] is non-finite
+ /// (infinite or NaN), or the exponent can be represented
+ /// by a normal [`struct@f16`] and no non-zero bits would
+ /// be truncated.
+ ///
+ /// "Lossless" does not mean the data is represented the
+ /// same as a decimal number. For example, an [`f32`]
+ /// and [`f64`] have the significant digits (excluding the
+ /// hidden bit) for a value closest to `1e35` of:
+ /// - `f32`: `110100001001100001100`
+ /// - `f64`: `11010000100110000110000000000000000000000000000000`
+ ///
+ /// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
+ /// while the value closest to `1e35` in [`f64`] is
+ /// `11010000100110000101110010110001110100110110000010`. This
+ /// makes it look like precision has been lost but this is
+ /// due to the approximations used to represent binary values as
+ /// a decimal.
+ ///
+ /// This does not respect signalling NaNs: if the value
+ /// is NaN or inf, then it will return that value.
+ #[inline]
+ pub const fn from_f32_lossless(value: f32) -> Option {
+ try_from_lossless!(
+ value => value,
+ half => f16,
+ full => f32,
+ half_bits => u16,
+ full_bits => u32,
+ to_half => from_f32
+ )
+ }
+
/// Constructs a 16-bit floating point value from a 64-bit floating point
/// value.
///
@@ -160,6 +198,41 @@ impl f16 {
f16(arch::f64_to_f16(value))
}
+ /// Create a [`struct@f16`] loslessly from an [`f64`].
+ ///
+ /// This is only true if the [`f64`] is non-finite
+ /// (infinite or NaN), or the exponent can be represented
+ /// by a normal [`struct@f16`] and no non-zero bits would
+ /// be truncated.
+ ///
+ /// "Lossless" does not mean the data is represented the
+ /// same as a decimal number. For example, an [`f32`]
+ /// and [`f64`] have the significant digits (excluding the
+ /// hidden bit) for a value closest to `1e35` of:
+ /// - `f32`: `110100001001100001100`
+ /// - `f64`: `11010000100110000110000000000000000000000000000000`
+ ///
+ /// However, the [`f64`] is displayed as `1.0000000409184788e+35`,
+ /// while the value closest to `1e35` in [`f64`] is
+ /// `11010000100110000101110010110001110100110110000010`. This
+ /// makes it look like precision has been lost but this is
+ /// due to the approximations used to represent binary values as
+ /// a decimal.
+ ///
+ /// This does not respect signalling NaNs: if the value
+ /// is NaN or inf, then it will return that value.
+ #[inline]
+ pub const fn from_f64_lossless(value: f64) -> Option {
+ try_from_lossless!(
+ value => value,
+ half => f16,
+ full => f64,
+ half_bits => u16,
+ full_bits => u64,
+ to_half => from_f64
+ )
+ }
+
/// Converts a [`struct@f16`] into the underlying bit representation.
#[inline]
#[must_use]
@@ -963,6 +1036,24 @@ impl From for f16 {
}
}
+impl TryFrom for f16 {
+ type Error = TryFromFloatError;
+
+ #[inline]
+ fn try_from(x: f32) -> Result {
+ Self::from_f32_lossless(x).ok_or(TryFromFloatError(()))
+ }
+}
+
+impl TryFrom for f16 {
+ type Error = TryFromFloatError;
+
+ #[inline]
+ fn try_from(x: f64) -> Result {
+ Self::from_f64_lossless(x).ok_or(TryFromFloatError(()))
+ }
+}
+
impl PartialEq for f16 {
#[inline]
fn eq(&self, other: &f16) -> bool {
@@ -1908,4 +1999,126 @@ mod test {
assert_eq!(const_bits, bits);
assert!(inst_bits.abs_diff(bits) <= max_diff);
}
+
+ #[test]
+ fn from_f32_lossless() {
+ let from_f32 = |v: f32| f16::from_f32_lossless(v);
+ let roundtrip = |v: f32, expected: Option| {
+ let half = from_f32(v);
+ assert_eq!(half, expected);
+ if !expected.is_none() {
+ let as_f32 = expected.unwrap().to_f32_const();
+ assert_eq!(v, as_f32);
+ }
+ };
+
+ assert_eq!(from_f32(f32::NAN).map(f16::is_nan), Some(true));
+ roundtrip(f32::INFINITY, Some(f16::INFINITY));
+ roundtrip(f32::NEG_INFINITY, Some(f16::NEG_INFINITY));
+ roundtrip(f32::from_bits(0b0_00000000_00000000000000000000000), Some(f16(0)));
+ roundtrip(f32::from_bits(0b1_00000000_00000000000000000000000), Some(f16(f16::SIGN_MASK)));
+ roundtrip(f32::from_bits(1), None);
+
+ // special truncation with denormals, etc.
+ roundtrip(f32::from_bits(0b0_01100111_00000000000000000000000), Some(f16(1)));
+ roundtrip(f32::from_bits(0b0_01101000_00000000000000000000000), Some(f16(2)));
+ roundtrip(f32::from_bits(0b0_01101000_10000000000000000000000), Some(f16(3)));
+ roundtrip(f32::from_bits(0b0_01100111_10000000000000000000000), None);
+ roundtrip(f32::from_bits(0b0_01101000_11000000000000000000000), None);
+ // ~2.2888184e-5 and has bits until 16 to the end, so truncated 2. but this is
+ // denormal as f16
+ roundtrip(f32::from_bits(0b0_01101111_00000000000000000000000), Some(f16(0x100)));
+ roundtrip(f32::from_bits(0b0_01101111_10000000000000000000000), Some(f16(0x180)));
+ roundtrip(f32::from_bits(0b0_01101111_11000000000000000000000), Some(f16(0x1c0)));
+ roundtrip(f32::from_bits(0b0_01101111_11000001000000000000000), Some(f16(0x1c1)));
+ roundtrip(f32::from_bits(0b0_01101111_11000001100000000000000), None);
+ //2.0f32
+ roundtrip(f32::from_bits(0b0_10000000_00000000000000000000000), Some(f16(0x4000)));
+ roundtrip(f32::from_bits(0b0_10000000_10000000000000000000000), Some(f16(0x4200)));
+ roundtrip(f32::from_bits(0b0_10000000_10000000010000000000000), Some(f16(0x4201)));
+ roundtrip(f32::from_bits(0b0_10000000_10000000011000000000000), None);
+ // check overflow
+ roundtrip(f32::from_bits(0b0_10001111_00000000000000000000000), None);
+ roundtrip(f32::from_bits(0b0_10001110_00000000000000000000000), Some(f16(0x7800)));
+ }
+
+ #[test]
+ fn from_f64_lossless() {
+ let from_f64 = |v: f64| f16::from_f64_lossless(v);
+ let roundtrip = |v: f64, expected: Option| {
+ let half = from_f64(v);
+ assert_eq!(half, expected);
+ if !expected.is_none() {
+ let as_f64 = expected.unwrap().to_f64_const();
+ assert_eq!(v, as_f64);
+ }
+ };
+
+ assert_eq!(from_f64(f64::NAN).map(f16::is_nan), Some(true));
+ roundtrip(f64::INFINITY, Some(f16::INFINITY));
+ roundtrip(f64::NEG_INFINITY, Some(f16::NEG_INFINITY));
+ roundtrip(
+ f64::from_bits(0b0_00000000000_0000000000000000000000000000000000000000000000000000),
+ Some(f16(0)),
+ );
+ roundtrip(
+ f64::from_bits(0b1_00000000000_0000000000000000000000000000000000000000000000000000),
+ Some(f16(f16::SIGN_MASK)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01110001010_1010100101011010010110110111111110000111101000001111),
+ None,
+ );
+ // check overflow to inf
+ roundtrip(
+ f64::from_bits(0b0_10000001110_1000000000000000000000000000000000000000000000000000),
+ Some(f16(0x7a00)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_10000001111_1000000000000000000000000000000000000000000000000000),
+ None,
+ );
+ // check denormals and truncation
+ roundtrip(
+ f64::from_bits(0b0_01111100111_0000000000000000000000000000000000000000000000000000),
+ Some(f16(1)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111100111_1000000000000000000000000000000000000000000000000000),
+ None,
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111101000_0000000000000000000000000000000000000000000000000000),
+ Some(f16(2)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111101000_1000000000000000000000000000000000000000000000000000),
+ Some(f16(3)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111101000_1100000000000000000000000000000000000000000000000000),
+ None,
+ );
+ // check basic, normal and positive numbers
+ roundtrip(
+ f64::from_bits(0b0_01111111000_0000000000000000000000000000000000000000000000000000),
+ Some(f16(0x2000)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111111000_1000000000000000000000000000000000000000000000000000),
+ Some(f16(0x2200)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111111000_1110000000000000000000000000000000000000000000000000),
+ Some(f16(0x2380)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111111000_1110000001000000000000000000000000000000000000000000),
+ Some(f16(0x2381)),
+ );
+ roundtrip(
+ f64::from_bits(0b0_01111111000_1110000001100000000000000000000000000000000000000000),
+ None,
+ );
+ }
}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..b4c0d5e
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,14 @@
+//! Error type for numeric conversion functions.
+
+use core::fmt;
+
+/// The error type returned when a checked integral type conversion fails.
+pub struct TryFromFloatError(pub(crate) ());
+
+impl fmt::Display for TryFromFloatError {
+ #[inline]
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let msg = "out of range integral type conversion attempted";
+ fmt::Display::fmt(msg, f)
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 415725a..f02942a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -70,11 +70,14 @@
mod bfloat;
mod binary16;
+mod error;
mod leading_zeros;
mod slice;
+mod try_from;
pub use bfloat::bf16;
pub use binary16::f16;
+pub use error::TryFromFloatError;
#[cfg(not(target_arch = "spirv"))]
pub use crate::slice::{HalfBitsSliceExt, HalfFloatSliceExt};
diff --git a/src/try_from.rs b/src/try_from.rs
new file mode 100644
index 0000000..1caa177
--- /dev/null
+++ b/src/try_from.rs
@@ -0,0 +1,134 @@
+// Try to convert a value from a wider type, only converting
+// the value if it can be losslessly converted, otherwise returning none.
+macro_rules! try_from_lossless {
+ (
+ value =>
+ $value:ident,half =>
+ $half:ident,full =>
+ $full:ident,half_bits =>
+ $half_bits:ident,full_bits =>
+ $full_bits:ident,to_half =>
+ $to_half:ident
+ ) => {{
+ // let's use `f16` and `f64` as an example:
+ // the `f64` is broken down into the following components:
+ // - sign: 1
+ // - exp: 111111111110000000000000000000000000000000000000000000000000000
+ // - mant: 1111111111111111111111111111111111111111111111111111
+ //
+ // value is stored as `2^n` where the lowest bit of the exp is the implicit
+ // hidden bit, that is, the `1`, while the top bit of `mantissa` is `1/2`,
+ // then `1/4`, etc. for an `f16`, we then have:
+ // - sign: 1
+ // - exp: 11111
+ // - mant: 1111111111
+ // or the bottom 42 bits are truncated during the conversion, if the exponents
+ // are in range. we only need to consider special cases, that is, subnormal
+ // floats, where all exponent bits are 0, if both types have the same number of
+ // exponent bits (`f32` to `bf16`). so only if the bottom `52 - 10` bits are
+ // `0`, then it has a lossless conversion
+ //
+ // we do have special cases for non-finite values, NaN and +/- infinity. since
+ // a NaN is still a NaN no matter the lower `51` bits in the longer type, if
+ // we can ignore the result no matter the lower bits.
+ let bits: $full_bits = unsafe { core::mem::transmute($value) };
+
+ // get our masks and extract the IEEE754 components.
+ const FULL_MANTISSA_BITS: u32 = <$full>::MANTISSA_DIGITS - 1;
+ const FULL_SIGN_MASK: $full_bits = 1 << (<$full_bits>::BITS - 1);
+ const FULL_EXPONENT_MASK: $full_bits =
+ (<$full>::MAX_EXP as $full_bits * 2 - 1) << FULL_MANTISSA_BITS;
+ const FULL_MANTISSA_MASK: $full_bits = (1 << FULL_MANTISSA_BITS) - 1;
+ let full_sign = bits & FULL_SIGN_MASK;
+ let full_exp = bits & FULL_EXPONENT_MASK;
+ let full_mant = bits & FULL_MANTISSA_MASK;
+
+ const HALF_MANTISSA_BITS: u32 = <$half>::MANTISSA_DIGITS - 1;
+ const HALF_EXPONENT_MASK: $half_bits =
+ (<$half>::MAX_EXP as $half_bits * 2 - 1) << HALF_MANTISSA_BITS;
+ let sign_shift = <$full_bits>::BITS - <$half_bits>::BITS;
+ let half_sign = (full_sign >> sign_shift) as $half_bits;
+
+ // we use the number of bits without the hidden bit.
+ // we want to know the number of bits truncated and a mask for
+ // all bits that could be truncated.
+ const TRUNCATED_BITS: u32 = FULL_MANTISSA_BITS - HALF_MANTISSA_BITS;
+
+ // check for if we have a special (non-finite) number
+ if full_exp == FULL_EXPONENT_MASK {
+ let half_exp = HALF_EXPONENT_MASK;
+ let half_mant = (full_mant >> TRUNCATED_BITS) as $half_bits;
+ return Some($half(half_sign | half_exp | half_mant));
+ }
+
+ // check for zero, which would otherwise underflow
+ if (bits & !FULL_SIGN_MASK) == 0 {
+ return Some($half(half_sign));
+ }
+
+ // need to get our unbiased exponent. exponents are stored with
+ // the value as (2^exp - (2^(expbits-1) - 1)`. the max, unbiased
+ // exp for `bf16` is `127` and the min non-denormal one is `-126`.
+ // we need the hidden bit in this biased exp.
+ const FULL_BIAS: i32 = <$full>::MAX_EXP - 1;
+ let full_biased = (full_exp >> FULL_MANTISSA_BITS) as i32;
+ let full_unbiased = full_biased - FULL_BIAS;
+
+ // now we need to know if our exponent is in the range. our range is from
+ // if your small exp is valid for our float, that is, unbiased it's in
+ // the range `2 - 2^(expbits-1)` or `1 - bias` for a normal float
+ // (biased exp `>= 1`), but a denormal float works so we want
+ // `1 - bias`. Our max exp finite is `2^(expbits-1) - 1` or `bias`.
+ // all special values always are valid, so we also accept when all
+ // the exponent bits are set. we have a special case: when the two
+ // exponents are the same number of bits: then it's **ALWAYS** valid.
+ //
+ // but this still needs to consider denormal values, or where we have
+ // no exp bits
+ const HALF_BIAS: i32 = <$half>::MAX_EXP - 1;
+ const HALF_MIN_EXP: i32 = 1 - HALF_BIAS;
+ const FULL_EXP_BITS: u32 = <$full_bits>::BITS - FULL_MANTISSA_BITS - 1;
+ const HALF_EXP_BITS: u32 = <$half_bits>::BITS - HALF_MANTISSA_BITS - 1;
+ const HALF_MAX_EXP: i32 = HALF_BIAS;
+ const HALF_MIN_DENORMAL_EXP: i32 = HALF_MIN_EXP - HALF_MANTISSA_BITS as i32;
+ let exp_in_range = if FULL_EXP_BITS == HALF_EXP_BITS {
+ true
+ } else {
+ full_unbiased >= HALF_MIN_DENORMAL_EXP && full_unbiased <= HALF_MAX_EXP
+ };
+ if !exp_in_range {
+ return None;
+ }
+
+ // get if we have any truncated bits, otherwise, we have an exact result
+ let half_biased = full_unbiased + HALF_BIAS;
+ let is_denormal = half_biased <= 0;
+ let truncated_bits = if is_denormal {
+ // NOTE: This needs an extra bit for what was formerly the hidden bit
+ (TRUNCATED_BITS as i32 - half_biased + 1) as u32
+ } else {
+ TRUNCATED_BITS
+ };
+ let truncated_mask: $full_bits = (1 << truncated_bits) - 1;
+ if bits & truncated_mask != 0 {
+ return None;
+ }
+
+ // now we need to reassemble our float components. remember if we have
+ // a denormal float in the result we need to move our implicit hidden
+ // bit out.
+ let full_hidden_bit: $full_bits = 1 << FULL_MANTISSA_BITS;
+ let (half_mant, half_exp) = if is_denormal {
+ let half_mant = ((full_mant | full_hidden_bit) >> truncated_bits) as $half_bits;
+ (half_mant, 0)
+ } else {
+ let half_mant = (full_mant >> truncated_bits) as $half_bits;
+ let half_exp = (half_biased as $half_bits) << HALF_MANTISSA_BITS;
+ (half_mant, half_exp)
+ };
+
+ Some($half(half_sign | half_exp | half_mant))
+ }};
+}
+
+pub(crate) use try_from_lossless;