From 1833c0cfe8bdaa29379c763dd733a09903a35801 Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Wed, 31 Jul 2024 21:35:23 +0000
Subject: [PATCH] Add float8_e4m3

---
 CHANGELOG.md                         |   3 +
 README.md                            |   5 +
 ml_dtypes/__init__.py                |   3 +
 ml_dtypes/_finfo.py                  |  64 +++++++++++++
 ml_dtypes/_src/dtypes.cc             |  28 ++++++
 ml_dtypes/include/float8.h           | 126 ++++++++++++++++++++++++-
 ml_dtypes/tests/custom_float_test.py |   8 ++
 ml_dtypes/tests/finfo_test.py        |   1 +
 ml_dtypes/tests/float8_test.cc       | 134 +++++++++++++++++++++++++--
 9 files changed, 363 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 618fe20b..9baeb540 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

 ## [Unreleased]

+* Added new 8-bit float type following IEEE 754 convention:
+  `ml_dtypes.float8_e4m3`.
+
 ## [0.4.0] - 2024-04-1

 * Updates `ml_dtypes` for compatibility with future NumPy 2.0 release.
diff --git a/README.md b/README.md
index 164da000..4921b49b 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format)
   format
 - `float8_*`: several experimental 8-bit floating point representations
   including:
+  * `float8_e4m3`
   * `float8_e4m3b11fnuz`
   * `float8_e4m3fn`
   * `float8_e4m3fnuz`
@@ -64,6 +65,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

 Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

+### `float8_e4m3`
+
+Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.
+
 ### `float8_e4m3b11fnuz`

 Exponent: 4, Mantissa: 3, bias: 11.
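For context, the type described in the README entry above is expected to behave like a tiny IEEE float once the package is built with this change. A minimal usage sketch (illustrative only, not part of the patch):

    import numpy as np
    import ml_dtypes

    # 0.5, 1.0 and 240.0 are exactly representable in float8_e4m3; 1e9 is far
    # above the finite maximum of 240 and should become inf, since the format
    # keeps IEEE-style inf/NaN encodings.
    x = np.array([0.5, 1.0, 240.0, 1e9], dtype=ml_dtypes.float8_e4m3)
    print(x)
    print(np.finfo(ml_dtypes.float8_e4m3).max)  # expected: 240.0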
diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py
index 80800285..fe0b1891 100644
--- a/ml_dtypes/__init__.py
+++ b/ml_dtypes/__init__.py
@@ -17,6 +17,7 @@
     "__version__",
     "bfloat16",
     "finfo",
+    "float8_e4m3",
     "float8_e4m3b11fnuz",
     "float8_e4m3fn",
     "float8_e4m3fnuz",
@@ -34,6 +35,7 @@
 from ml_dtypes._finfo import finfo
 from ml_dtypes._iinfo import iinfo
 from ml_dtypes._ml_dtypes_ext import bfloat16
+from ml_dtypes._ml_dtypes_ext import float8_e4m3
 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
 from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
 from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
@@ -46,6 +48,7 @@
 import numpy as np

 bfloat16: Type[np.generic]
+float8_e4m3: Type[np.generic]
 float8_e4m3b11fnuz: Type[np.generic]
 float8_e4m3fn: Type[np.generic]
 float8_e4m3fnuz: Type[np.generic]
diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py
index 451f2766..3f7aa48d 100644
--- a/ml_dtypes/_finfo.py
+++ b/ml_dtypes/_finfo.py
@@ -17,6 +17,7 @@
 from typing import Dict

 from ml_dtypes._ml_dtypes_ext import bfloat16
+from ml_dtypes._ml_dtypes_ext import float8_e4m3
 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
 from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
 from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
@@ -25,6 +26,7 @@
 import numpy as np

 _bfloat16_dtype = np.dtype(bfloat16)
+_float8_e4m3_dtype = np.dtype(float8_e4m3)
 _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
 _float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
 _float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
@@ -41,6 +43,15 @@ def __init__(self):
     self.smallest_subnormal = bfloat16(smallest_subnormal)


+class _Float8E4m3MachArLike:
+
+  def __init__(self):
+    smallest_normal = float.fromhex("0x1p-6")
+    self.smallest_normal = float8_e4m3(smallest_normal)
+    smallest_subnormal = float.fromhex("0x1p-9")
+    self.smallest_subnormal = float8_e4m3(smallest_subnormal)
+
+
 class _Float8E4m3b11fnuzMachArLike:

   def __init__(self):
@@ -135,6 +146,51 @@ def float_to_str(f):
     # pylint: enable=protected-access
     return obj

+  @staticmethod
+  def _float8_e4m3_finfo():
+    def float_to_str(f):
+      return "%6.2e" % float(f)
+
+    tiny = float.fromhex("0x1p-6")  # 1/64 min normal
+    resolution = 0.1
+    eps = float.fromhex("0x1p-3")  # 1/8
+    epsneg = float.fromhex("0x1p-4")  # 1/16
+    max_ = float.fromhex("0x1.Ep7")  # 240 max normal
+
+    obj = object.__new__(np.finfo)
+    obj.dtype = _float8_e4m3_dtype
+    obj.bits = 8
+    obj.eps = float8_e4m3(eps)
+    obj.epsneg = float8_e4m3(epsneg)
+    obj.machep = -3
+    obj.negep = -4
+    obj.max = float8_e4m3(max_)
+    obj.min = float8_e4m3(-max_)
+    obj.nexp = 4
+    obj.nmant = 3
+    obj.iexp = obj.nexp
+    obj.maxexp = 8
+    obj.minexp = -6
+    obj.precision = 1
+    obj.resolution = float8_e4m3(resolution)
+    # pylint: disable=protected-access
+    obj._machar = _Float8E4m3MachArLike()
+    if not hasattr(obj, "tiny"):
+      obj.tiny = float8_e4m3(tiny)
+    if not hasattr(obj, "smallest_normal"):
+      obj.smallest_normal = obj._machar.smallest_normal
+    obj.smallest_subnormal = obj._machar.smallest_subnormal
+
+    obj._str_tiny = float_to_str(tiny)
+    obj._str_smallest_normal = float_to_str(tiny)
+    obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
+    obj._str_max = float_to_str(max_)
+    obj._str_epsneg = float_to_str(epsneg)
+    obj._str_eps = float_to_str(eps)
+    obj._str_resolution = float_to_str(resolution)
+    # pylint: enable=protected-access
+    return obj
+
   @staticmethod
   def _float8_e4m3b11fnuz_finfo():
     def float_to_str(f):
@@ -369,6 +425,14 @@ def __new__(cls, dtype):
       if _bfloat16_dtype not in cls._finfo_cache:
         cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
       return cls._finfo_cache[_bfloat16_dtype]
+    if (
+        isinstance(dtype, str)
+        and dtype == "float8_e4m3"
+        or dtype == _float8_e4m3_dtype
+    ):
+      if _float8_e4m3_dtype not in cls._finfo_cache:
+        cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
+      return cls._finfo_cache[_float8_e4m3_dtype]
     if (
         isinstance(dtype, str)
        and dtype == "float8_e4m3b11fnuz"
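The finfo wiring above implies the following values; a quick sanity check (illustrative, assuming a build that includes this patch):

    import numpy as np
    import ml_dtypes

    fi = np.finfo(ml_dtypes.float8_e4m3)
    print(fi.max)                 # expected: 240.0
    print(fi.eps)                 # expected: 0.125        (2**-3)
    print(fi.smallest_normal)     # expected: 0.015625     (2**-6)
    print(fi.smallest_subnormal)  # expected: 0.001953125  (2**-9)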
diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc
index 1dde49b4..87f7578f 100644
--- a/ml_dtypes/_src/dtypes.cc
+++ b/ml_dtypes/_src/dtypes.cc
@@ -60,6 +60,20 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
   static constexpr char kNpyDescrByteorder = '=';
 };

+template <>
+struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
+  typedef float8_e4m3 T;
+  static constexpr bool is_floating = true;
+  static constexpr bool is_integral = false;
+  static constexpr const char* kTypeName = "float8_e4m3";
+  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3";
+  static constexpr const char* kTpDoc = "float8_e4m3 floating-point values";
+  // Set e4m3 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
+  static constexpr char kNpyDescrKind = 'V';  // Void
+  static constexpr char kNpyDescrType = '7';  // '4' is reserved for e4m3fn
+  static constexpr char kNpyDescrByteorder = '=';  // Native byte order
+};
+
 template <>
 struct TypeDescriptor<float8_e4m3b11fnuz>
     : CustomFloatType<float8_e4m3b11fnuz> {
@@ -269,6 +283,9 @@ bool Initialize() {
   if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
     return false;
   }
+  if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
+    return false;
+  }
   if (!RegisterFloatDtype<float8_e4m3b11fnuz>(numpy.get())) {
     return false;
   }
@@ -319,6 +336,12 @@ bool Initialize() {
   success &= RegisterTwoWayCustomCast();
   success &= RegisterTwoWayCustomCast();
   success &= RegisterTwoWayCustomCast();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3b11fnuz>();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn>();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz>();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2>();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2fnuz>();
+  success &= RegisterTwoWayCustomCast<float8_e4m3, bfloat16>();
   success &= RegisterOneWayCustomCast();
   success &= RegisterOneWayCustomCast();
   return success;
@@ -349,6 +372,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
     return nullptr;
   }

+  if (PyObject_SetAttrString(m.get(), "float8_e4m3",
+                             reinterpret_cast<PyObject*>(
+                                 TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {
+    return nullptr;
+  }
   if (PyObject_SetAttrString(
           m.get(), "float8_e4m3b11fnuz",
           reinterpret_cast<PyObject*>(
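The dtype registration and the custom-cast registrations above are what let NumPy's astype() move between float8_e4m3 and the other types; a small sketch of the intended effect (illustrative only, not part of the patch):

    import numpy as np
    import ml_dtypes

    a = np.array([0.25, 1.5, 96.0], dtype=ml_dtypes.float8_e4m3)
    print(a.astype(np.float32))             # standard-float casts come with RegisterFloatDtype
    print(a.astype(ml_dtypes.float8_e5m2))  # custom-to-custom casts come from RegisterTwoWayCustomCast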
diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h
index 65b177a1..93aa0da4 100644
--- a/ml_dtypes/include/float8.h
+++ b/ml_dtypes/include/float8.h
@@ -43,6 +43,7 @@ namespace ml_dtypes {
 namespace float8_internal {

 // Forward-declarations of classes.
+class float8_e4m3;
 class float8_e4m3fn;
 class float8_e4m3fnuz;
 class float8_e4m3b11fnuz;
@@ -243,6 +244,20 @@ template <typename T>
 using RequiresIsDerivedFromFloat8Base =
     std::enable_if_t<std::is_base_of_v<float8_base<T>, T>, int>;

+class float8_e4m3 : public float8_base<float8_e4m3> {
+  // Exponent: 4, Mantissa: 3, bias: 7.
+  // IEEE 754.
+ private:
+  using Base = float8_base<float8_e4m3>;
+  friend class float8_base<float8_e4m3>;
+  using Base::Base;
+
+ public:
+  template <typename T, RequiresIsDerivedFromFloat8Base<T> = 0>
+  explicit EIGEN_DEVICE_FUNC float8_e4m3(T f8)
+      : float8_e4m3(ConvertFrom(f8)) {}
+};
+
 class float8_e4m3fn : public float8_base<float8_e4m3fn> {
   // Exponent: 4, Mantissa: 3, bias: 7.
   // Extended range: no inf, NaN represented by 0bS111'1111.
@@ -369,6 +384,8 @@ class float8_e5m2fnuz : public float8_base<float8_e5m2fnuz> {
  public:
   explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e5m2& f8)
       : float8_e5m2fnuz(ConvertFrom(f8)) {}
+  explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3& f8)
+      : float8_e5m2fnuz(ConvertFrom(f8)) {}
   explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3b11fnuz& f8)
       : float8_e5m2fnuz(ConvertFrom(f8)) {}
   explicit EIGEN_DEVICE_FUNC float8_e5m2fnuz(const float8_e4m3fn& f8)
@@ -473,6 +490,70 @@ struct numeric_limits_float8_base {
   // NOLINTEND
 };

+struct numeric_limits_float8_e4m3 : public numeric_limits_float8_base {
+ private:
+  static inline constexpr const int kExponentBias = 7;
+  static inline constexpr const int kMantissaBits = 3;
+
+ public:
+  // NOLINTBEGIN: these names must match std::numeric_limits.
+  static inline constexpr const int digits = kMantissaBits + 1;
+  static inline constexpr const int digits10 = Digits10FromDigits(digits);
+  static inline constexpr const int max_digits10 =
+      MaxDigits10FromDigits(digits);
+  static inline constexpr const int min_exponent = (1 - kExponentBias) + 1;
+  static inline constexpr const int min_exponent10 =
+      MinExponent10FromMinExponent(min_exponent);
+  static inline constexpr const int max_exponent = 0b1111 - kExponentBias;
+  static inline constexpr const int max_exponent10 =
+      MaxExponent10FromMaxExponentAndDigits(max_exponent, digits);
+  static inline constexpr const bool is_iec559 = true;
+  static inline constexpr const bool has_infinity = true;
+  static inline constexpr const bool has_signaling_NaN = true;
+  // NOLINTEND
+
+  // 1.0 * 2^(0b0001 - 7) = 1.0 * 2^-6 = 1/64 (min normal)
+  static constexpr float8_e4m3 min() {
+    return float8_e4m3::FromRep(1 << kMantissaBits);
+  }
+  // -(1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = -(1 + 7/8) * 2^7 = -240
+  static constexpr float8_e4m3 lowest() {
+    return float8_e4m3::FromRep(0b1'1110'111);
+  }
+  // (1 + 0b111 * 2^-3) * 2^(0b1110 - 7) = (1 + 7/8) * 2^7 = 240
+  static constexpr float8_e4m3 max() {
+    return float8_e4m3::FromRep(0b0'1110'111);
+  }
+  // 1.0 * 2^-3 = 0.125
+  static constexpr float8_e4m3 epsilon() {
+    return float8_e4m3::FromRep((-kMantissaBits + kExponentBias)
+                                << kMantissaBits);
+  }
+  // 1.0 * 2^-1 = 0.5
+  static constexpr float8_e4m3 round_error() {
+    return float8_e4m3::FromRep((-1 + kExponentBias) << kMantissaBits);
+  }
+  static constexpr float8_e4m3 infinity() {
+    return float8_e4m3::FromRep(0b0'1111'000);
+  }
+  static constexpr float8_e4m3 quiet_NaN() {
+    // IEEE 754-2019 6.2.1: "All binary NaN bit strings have the sign bit S set
+    // to 0 or 1 and all the bits of the biased exponent field E set to 1
+    // (see 3.4). A quiet NaN bit string should be encoded with the first bit
+    // (d1) of the trailing significand field T being 1."
+    return float8_e4m3::FromRep(0b0'1111'100);
+  }
+  static constexpr float8_e4m3 signaling_NaN() {
+    // IEEE 754-2019 6.2.1: "A signaling NaN bit string should be encoded with
+    // the first bit of the trailing significand field being 0."
+    return float8_e4m3::FromRep(0b0'1111'001);
+  }
+  // 2^(-6) * 2^(-3) = 2^-9 = 1/512 (min denormal)
+  static constexpr float8_e4m3 denorm_min() {
+    return float8_e4m3::FromRep(0b0'0000'001);
+  }
+};
+
 struct numeric_limits_float8_e4m3fn : public numeric_limits_float8_base {
  private:
   static inline constexpr const int kExponentBias = 7;
   static inline constexpr const int kMantissaBits = 3;
@@ -769,6 +850,10 @@ struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base {
 namespace std {
 // Standard-library overrides. Note that these are picked up by Eigen as well.

+template <>
+struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3>
+    : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3 {};
+
 template <>
 struct numeric_limits<ml_dtypes::float8_internal::float8_e4m3fn>
     : public ml_dtypes::float8_internal::numeric_limits_float8_e4m3fn {};
@@ -793,6 +878,14 @@ struct numeric_limits

 namespace ml_dtypes {
 namespace float8_internal {
+constexpr inline float8_e4m3 abs(const float8_e4m3& a) {
+  return float8_e4m3::FromRep(a.rep() & 0b0'1111'111);
+}
+
+constexpr inline bool(isnan)(const float8_e4m3& a) {
+  return abs(a).rep() > std::numeric_limits<float8_e4m3>::infinity().rep();
+}
+
 // Free-functions for use with ADL and in Eigen.
 constexpr inline float8_e4m3fn abs(const float8_e4m3fn& a) {
   return float8_e4m3fn::FromRep(a.rep() & 0b0'1111'111);
@@ -1175,7 +1268,7 @@ struct ConvertImpl
 struct ConvertImpl<float8_e4m3fn, float8_e5m2, kSaturate, kTruncate> {
   static EIGEN_DEVICE_FUNC inline float8_e5m2 run(float8_e4m3fn from) {
@@ -1278,6 +1371,7 @@ EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(Derived from) {
 }  // namespace float8_internal

 // Exported types.
+using float8_e4m3 = float8_internal::float8_e4m3;
 using float8_e4m3fn = float8_internal::float8_e4m3fn;
 using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz;
 using float8_e4m3b11fnuz = float8_internal::float8_e4m3b11fnuz;
@@ -1290,6 +1384,18 @@ using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz;
 namespace Eigen {
 namespace numext {

+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e4m3
+bit_cast<ml_dtypes::float8_e4m3, uint8_t>(const uint8_t& src) {
+  return ml_dtypes::float8_e4m3::FromRep(src);
+}
+
+template <>
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t
+bit_cast<uint8_t, ml_dtypes::float8_e4m3>(const ml_dtypes::float8_e4m3& src) {
+  return src.rep();
+}
+
 template <>
 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC ml_dtypes::float8_e4m3fn
 bit_cast<ml_dtypes::float8_e4m3fn, uint8_t>(const uint8_t& src) {
@@ -1319,6 +1425,12 @@ bit_cast<uint8_t, ml_dtypes::float8_e5m2>(const ml_dtypes::float8_e5m2& src) {

 // Work-around for isinf/isnan/isfinite issue on aarch64.
 namespace internal {
+template <>
+EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3>(
+    const ml_dtypes::float8_e4m3& x) {
+  return ml_dtypes::float8_internal::isinf(x);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e4m3fn>(
     const ml_dtypes::float8_e4m3fn& x) {
@@ -1349,6 +1461,12 @@ EIGEN_DEVICE_FUNC inline bool isinf_impl<ml_dtypes::float8_e5m2fnuz>(
     const ml_dtypes::float8_e5m2fnuz& x) {
   return ml_dtypes::float8_internal::isinf(x);
 }

+template <>
+EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3>(
+    const ml_dtypes::float8_e4m3& x) {
+  return ml_dtypes::float8_internal::isnan(x);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e4m3fn>(
     const ml_dtypes::float8_e4m3fn& x) {
@@ -1379,6 +1497,12 @@ EIGEN_DEVICE_FUNC inline bool isnan_impl<ml_dtypes::float8_e5m2fnuz>(
     const ml_dtypes::float8_e5m2fnuz& x) {
   return ml_dtypes::float8_internal::isnan(x);
 }

+template <>
+EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3>(
+    const ml_dtypes::float8_e4m3& x) {
+  return ml_dtypes::float8_internal::isfinite(x);
+}
+
 template <>
 EIGEN_DEVICE_FUNC inline bool isfinite_impl<ml_dtypes::float8_e4m3fn>(
     const ml_dtypes::float8_e4m3fn& x) {
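The numeric_limits encodings defined above can be inspected from Python through the raw byte representation; a short sketch (illustrative; `rep` is a hypothetical helper, not part of the patch):

    import numpy as np
    import ml_dtypes

    def rep(v):
        # Raw 8-bit encoding of a float8_e4m3 scalar.
        return np.array([v], dtype=ml_dtypes.float8_e4m3).tobytes()[0]

    print(hex(rep(240.0)))    # expected: 0x77 -> 0b0'1110'111, finite max
    print(hex(rep(np.inf)))   # expected: 0x78 -> 0b0'1111'000, infinity
    print(hex(rep(2 ** -9)))  # expected: 0x01 -> 0b0'0000'001, smallest subnormal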
diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py
index 5333b487..cd76b3ec 100644
--- a/ml_dtypes/tests/custom_float_test.py
+++ b/ml_dtypes/tests/custom_float_test.py
@@ -30,6 +30,7 @@
 import numpy as np

 bfloat16 = ml_dtypes.bfloat16
+float8_e4m3 = ml_dtypes.float8_e4m3
 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
 float8_e4m3fn = ml_dtypes.float8_e4m3fn
 float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
@@ -108,6 +109,7 @@ def dtype_has_inf(dtype):

 FLOAT_DTYPES = [
     bfloat16,
+    float8_e4m3,
     float8_e4m3b11fnuz,
     float8_e4m3fn,
     float8_e4m3fnuz,
@@ -146,6 +148,11 @@ def dtype_has_inf(dtype):
 # Values that should round trip exactly to integer and back.
 INT_VALUES = {
     bfloat16: [0, 1, 2, 10, 34, 47, 128, 255, 256, 512],
+    float8_e4m3: list(
+        itertools.chain.from_iterable(
+            range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8)
+        )
+    ),
     float8_e4m3b11fnuz: [*range(16), *range(16, 30, 2)],
     float8_e4m3fn: list(
         itertools.chain.from_iterable(
             range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(9)
         )
     ),
@@ -171,6 +178,7 @@ def dtype_has_inf(dtype):

 BITS_TYPE = {
     bfloat16: np.uint16,
+    float8_e4m3: np.uint8,
     float8_e4m3b11fnuz: np.uint8,
     float8_e4m3fn: np.uint8,
     float8_e4m3fnuz: np.uint8,
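For reference, the INT_VALUES generator added above enumerates every positive integer up to 240 that carries at most four significant bits, i.e. the integers float8_e4m3 can round-trip exactly; a quick way to see what it produces (illustrative):

    import itertools

    vals = list(
        itertools.chain.from_iterable(
            range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(8)
        )
    )
    print(vals[:10])  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(vals[-3:])  # [208, 224, 240]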
diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py
index 855c00ba..3999476b 100644
--- a/ml_dtypes/tests/finfo_test.py
+++ b/ml_dtypes/tests/finfo_test.py
@@ -19,6 +19,7 @@

 ALL_DTYPES = [
     ml_dtypes.bfloat16,
+    ml_dtypes.float8_e4m3,
     ml_dtypes.float8_e4m3b11fnuz,
     ml_dtypes.float8_e4m3fn,
     ml_dtypes.float8_e4m3fnuz,
diff --git a/ml_dtypes/tests/float8_test.cc b/ml_dtypes/tests/float8_test.cc
index ae20290a..c3a4841e 100644
--- a/ml_dtypes/tests/float8_test.cc
+++ b/ml_dtypes/tests/float8_test.cc
@@ -40,6 +40,8 @@ struct Float8TestParamNames {
       return "float8_e4m3fn";
     } else if constexpr (std::is_same_v<TypeParam, float8_e4m3b11fnuz>) {
       return "float8_e4m3b11fnuz";
+    } else if constexpr (std::is_same_v<TypeParam, float8_e4m3>) {
+      return "float8_e4m3";
     } else if constexpr (std::is_same_v<TypeParam, float8_e5m2>) {
       return "float8_e5m2";
     } else if constexpr (std::is_same_v<TypeParam, float8_e5m2fnuz>) {
@@ -52,11 +54,42 @@
 };

 using Float8Types =
-    ::testing::Types<float8_e4m3fn, float8_e5m2, float8_e4m3b11fnuz,
-                     float8_e4m3fnuz, float8_e5m2fnuz>;
+    ::testing::Types<float8_e4m3, float8_e4m3fn, float8_e5m2,
+                     float8_e4m3b11fnuz, float8_e4m3fnuz, float8_e5m2fnuz>;
 TYPED_TEST_SUITE(Float8Test, Float8Types, Float8TestParamNames);

 TEST(Float8E4m3Test, NumericLimits) {
+  EXPECT_TRUE(
+      Eigen::numext::isnan(std::numeric_limits<float8_e4m3>::quiet_NaN()));
+  EXPECT_TRUE(
+      Eigen::numext::isnan(std::numeric_limits<float8_e4m3>::signaling_NaN()));
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::min()),
+            std::exp2(-6));
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::max()), 240);
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::lowest()),
+            -240);
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::epsilon()),
+            0.125);
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::round_error()),
+            0.5);
+  EXPECT_TRUE(
+      Eigen::numext::isinf(std::numeric_limits<float8_e4m3>::infinity()));
+  EXPECT_EQ(static_cast<float>(std::numeric_limits<float8_e4m3>::denorm_min()),
+            std::exp2(-9));
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::digits, 4);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::digits10, 0);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::max_digits10, 3);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::min_exponent, -5);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::min_exponent10, -1);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::max_exponent, 8);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::max_exponent10, 2);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::is_iec559, true);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::has_infinity, true);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::has_quiet_NaN, true);
+  EXPECT_EQ(std::numeric_limits<float8_e4m3>::has_signaling_NaN, true);
+}
+
+TEST(Float8E4m3fnTest, NumericLimits) {
   EXPECT_TRUE(
       Eigen::numext::isnan(std::numeric_limits<float8_e4m3fn>::quiet_NaN()));
   EXPECT_TRUE(Eigen::numext::isnan(
       std::numeric_limits<float8_e4m3fn>::signaling_NaN()));
@@ -453,7 +486,7 @@ TYPED_TEST(Float8Test, DoubleRound) {
 #endif
 }

-TEST(Float8Test, Float8E5m2_To_Float8E4m3) {
+TEST(Float8Test, Float8E5m2_To_Float8E4m3fn) {
   // Saturation.
   float8_e5m2 max = std::numeric_limits<float8_e5m2>::max();
   float8_e4m3fn saturated = float8_e4m3fn::ConvertFrom(max);
   EXPECT_EQ(saturated, std::numeric_limits<float8_e4m3fn>::max());
   saturated = float8_e4m3fn::ConvertFrom(max);
   EXPECT_EQ(saturated, std::numeric_limits<float8_e4m3fn>::max());

   // Truncation - only occurs for e4m3.
   float8_e5m2 less_than_two = float8_e5m2::FromRep(0x3F);
   float8_e4m3fn truncated =
       float8_e4m3fn::template ConvertFrom(less_than_two);
   EXPECT_LT(static_cast<float>(truncated), 2);

   float8_e4m3fn truncated_subnorm =
       float8_e4m3fn::template ConvertFrom(float8_e5m2::FromRep(0x07));
   EXPECT_EQ(truncated_subnorm.rep(), 0x03);
 }

-TEST(Float8Test, Half_To_Float8E4m3) {
+TEST(Float8Test, Half_To_Float8E4m3fn) {
   Eigen::half big_half(0x1.dfcp+8f);
-  float8_e4m3fn big_e4m3 =
+  float8_e4m3fn big_e4m3fn =
       float8_e4m3fn::ConvertFrom(big_half);
-  EXPECT_EQ(big_e4m3.rep(), std::numeric_limits<float8_e4m3fn>::max().rep());
+  EXPECT_EQ(big_e4m3fn.rep(), std::numeric_limits<float8_e4m3fn>::max().rep());
 }

 TEST(Float8Test, Float8E5m2_To_Float8E4m3b11fnuz) {
@@ -529,7 +562,7 @@ TEST(Float8Test, Float8E5m2_To_Float8E4m3b11fnuz) {
   }
 }

-TEST(Float8Test, Float8E4m3b11fnuz_To_Float8E4m3) {
+TEST(Float8Test, Float8E4m3b11fnuz_To_Float8E4m3fn) {
   // Saturation.
   float8_e4m3b11fnuz max = std::numeric_limits<float8_e4m3b11fnuz>::max();
   float8_e4m3fn saturated = float8_e4m3fn::ConvertFrom(max);
@@ -578,6 +611,20 @@ TEST(Float8Test, Float8E4m3b11fnuz_To_Float8E4m3) {
 }

 TEST(Float8Test, Float8E4m3_To_Float8E5m2) {
+  // Truncation and rounding of a number ever-so-slightly less than 2.
+  float8_e4m3 less_than_two = float8_e4m3::FromRep(0x3F);
+  float8_e5m2 truncated =
+      float8_e5m2::template ConvertFrom</*kSaturate=*/false,
+                                        /*kTruncate=*/true>(less_than_two);
+  EXPECT_LT(static_cast<float>(truncated), 2);
+
+  float8_e5m2 rounded =
+      float8_e5m2::template ConvertFrom</*kSaturate=*/false,
+                                        /*kTruncate=*/false>(less_than_two);
+  EXPECT_EQ(static_cast<float>(rounded), 2);
+}
+
+TEST(Float8Test, Float8E4m3fn_To_Float8E5m2) {
   // Truncation and rounding of a number ever-so-slightly less than 2.
   float8_e4m3fn less_than_two = float8_e4m3fn::FromRep(0x3F);
   float8_e5m2 truncated =
       float8_e5m2::template ConvertFrom(less_than_two);
   EXPECT_LT(static_cast<float>(truncated), 2);

   float8_e5m2 rounded =
       float8_e5m2::template ConvertFrom(less_than_two);
   EXPECT_EQ(static_cast<float>(rounded), 2);
 }

+TEST(Float8Test, Half_To_Float8E4m3) {
+  // Special values, NaN.
+  Eigen::half inf =
+      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0x7C00));
+  EXPECT_EQ(static_cast<float8_e4m3>(inf).rep(), 0x78);
+  Eigen::half ninf =
+      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0xFC00));
+  EXPECT_EQ(static_cast<float8_e4m3>(ninf).rep(), 0xF8);
+
+  Eigen::half nan =
+      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0x7C01));
+  EXPECT_EQ(static_cast<float8_e4m3>(nan).rep(), 0x7C);
+  Eigen::half nnan =
+      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0xFC01));
+  EXPECT_EQ(static_cast<float8_e4m3>(nnan).rep(), 0xFC);
+
+  // Rounding vs truncation.
+  Eigen::half less_than_two =
+      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0x3FFF));
+  EXPECT_EQ((float8_e4m3::ConvertFrom</*kSaturate=*/false,
+                                      /*kTruncate=*/false>(less_than_two)
+                 .rep()),
+            0x40);
+  EXPECT_EQ((float8_e4m3::ConvertFrom</*kSaturate=*/false,
+                                      /*kTruncate=*/true>(less_than_two)
+                 .rep()),
+            0x3F);
+  EXPECT_EQ((float8_e4m3::ConvertFrom</*kSaturate=*/false,
+                                      /*kTruncate=*/false>(-less_than_two)
+                 .rep()),
+            0xC0);
+  EXPECT_EQ((float8_e4m3::ConvertFrom</*kSaturate=*/false,
+                                      /*kTruncate=*/true>(-less_than_two)
+                 .rep()),
+            0xBF);
+
+  // Saturation.
+  // f8e4m3=0.1110.111 0x1.Ep+7  f16=0.10110.1110000000 uint16=0x5B80
+  // f8e4m3=0.1111.000 0x1.0p+8  f16=0.10111.0000000000 uint16=0x5C00
+  for (uint16_t i = 0x5B80; i < 0x5C00; ++i) {
+    Eigen::half big_half = Eigen::numext::bit_cast<Eigen::half>(i);
+    float big_float = static_cast<float>(big_half);
+    EXPECT_EQ(
+        (float8_e4m3::ConvertFrom</*kSaturate=*/true, /*kTruncate=*/false>(
+             big_half)
+             .rep()),
+        (float8_e4m3::ConvertFrom</*kSaturate=*/true, /*kTruncate=*/false>(
+             big_float)
+             .rep()))
+        << i;
+    EXPECT_EQ(
+        (float8_e4m3::ConvertFrom</*kSaturate=*/true, /*kTruncate=*/false>(
+             -big_half)
+             .rep()),
+        (float8_e4m3::ConvertFrom</*kSaturate=*/true, /*kTruncate=*/false>(
+             -big_float)
+             .rep()))
+        << i;
+  }
+}
+
 TEST(Float8Test, Half_To_Float8E5m2) {
   // Special values, NaN.
   Eigen::half inf =
       Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0x7C00));
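The new Half_To_Float8E4m3 test above pins down conversion behaviour right at the top of the e4m3 range; the same behaviour should be visible from Python. A rough sketch (illustrative; exact outputs depend on round-to-nearest-even at the overflow boundary):

    import numpy as np
    import ml_dtypes

    x = np.array([239.0, 244.0, 248.0], dtype=ml_dtypes.float8_e4m3)
    print(x)  # expected: [240, 240, inf] -- 248 sits at the midpoint and rounds up to inf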
@@ -723,6 +831,15 @@ TYPED_TEST(Float8Test, CallTheConstOperator) {
   }
 }

+TEST(Float8E4m3Test, SmallCastToDenormal) {
+  // Special edge-case where rounding to a normalized value would
+  // normally round down, but rounding to a subnormal rounds up.
+  float x = std::ldexp(1.3125, -8);
+  float8_e4m3 y = static_cast<float8_e4m3>(x);
+  float z = static_cast<float>(y);
+  EXPECT_EQ(z, std::ldexp(1.5, -8));
+}
+
 TEST(Float8E5m2Test, SmallCastToDenormal) {
   // Special edge-case where rounding to a normalized value would
   // normally round down, but rounding to a subnormal rounds up.
@@ -755,6 +872,7 @@ struct Float8CastTestParamNames {
   GEN_LONG_DOUBLE_PAIR(Type)                                        \
   std::pair, std::pair,                                             \
       std::pair, std::pair,                                         \
+      std::pair<Type, float8_e4m3>,                                 \
       std::pair, std::pair,                                         \
       std::pair, std::pair,                                         \
       std::pair, std::pair,                                         \

 #define GEN_TYPE_PAIRS()                                            \
   GEN_DEST_TYPES(float8_e4m3fn), GEN_DEST_TYPES(float8_e4m3b11fnuz), \
       GEN_DEST_TYPES(float8_e5m2), GEN_DEST_TYPES(float8_e4m3fnuz),  \
-      GEN_DEST_TYPES(float8_e5m2fnuz)
+      GEN_DEST_TYPES(float8_e5m2fnuz), GEN_DEST_TYPES(float8_e4m3)

 using Float8CastTypePairs = ::testing::Types<GEN_TYPE_PAIRS()>;