diff --git a/CHANGELOG.md b/CHANGELOG.md index 48788cdf..4d567d32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Added new 8-bit float types following IEEE 754 convention: `ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`. +* Added new 4-bit and 6-bit float types: + `ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`. * Fix outputs of float `divmod` and `floor_divide` when denominator is zero. ## [0.4.0] - 2024-04-1 diff --git a/README.md b/README.md index 45a18bd1..bf6f1066 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,10 @@ * `float8_e4m3fnuz` * `float8_e5m2` * `float8_e5m2fnuz` +- Microscaling (MX) sub-byte floating point representations including: + * `float4_e2m1fn` + * `float6_e2m3fn` + * `float6_e3m2fn` - `int2`, `int4`, `uint2` and `uint4`: low precision integer types. See below for specifications of these number formats. @@ -66,6 +70,39 @@ A `bfloat16` number is a single-precision float truncated at 16 bits. Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf. +### `float4_e2m1fn` + +Exponent: 2, Mantissa: 1, bias: 1. + +Extended range: no inf, no NaN. + +Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4 +bits are unused). NaN representation is undefined. + +Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`] + +### `float6_e2m3fn` + +Exponent: 2, Mantissa: 3, bias: 1. + +Extended range: no inf, no NaN. + +Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2 +bits are unused). NaN representation is undefined. + +Possible values range: [`-7.5`; `7.5`] + +### `float6_e3m2fn` + +Exponent: 3, Mantissa: 2, bias: 3. + +Extended range: no inf, no NaN. + +Microscaling format, 4 bits (encoding: `0bSEEEMM`) using byte storage (higher 2 +bits are unused). NaN representation is undefined. + +Possible values range: [`-28`; `28`] + ### `float8_e3m4` Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf. diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 3942db9d..094b6ca7 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -17,6 +17,9 @@ "__version__", "bfloat16", "finfo", + "float4_e2m1fn", + "float6_e2m3fn", + "float6_e3m2fn", "float8_e3m4", "float8_e4m3", "float8_e4m3b11fnuz", @@ -36,6 +39,9 @@ from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float4_e2m1fn +from ml_dtypes._ml_dtypes_ext import float6_e2m3fn +from ml_dtypes._ml_dtypes_ext import float6_e3m2fn from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz @@ -50,6 +56,9 @@ import numpy as np bfloat16: Type[np.generic] +float4_e2m1fn: Type[np.generic] +float6_e2m3fn: Type[np.generic] +float6_e3m2fn: Type[np.generic] float8_e3m4: Type[np.generic] float8_e4m3: Type[np.generic] float8_e4m3b11fnuz: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 9d62e3a2..84989d5a 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -17,6 +17,9 @@ from typing import Dict from ml_dtypes._ml_dtypes_ext import bfloat16 +from ml_dtypes._ml_dtypes_ext import float4_e2m1fn +from ml_dtypes._ml_dtypes_ext import float6_e2m3fn +from ml_dtypes._ml_dtypes_ext import float6_e3m2fn from ml_dtypes._ml_dtypes_ext import float8_e3m4 from ml_dtypes._ml_dtypes_ext import float8_e4m3 from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz @@ -27,6 +30,9 @@ import numpy as np _bfloat16_dtype = np.dtype(bfloat16) +_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn) +_float6_e2m3fn_dtype = np.dtype(float6_e2m3fn) +_float6_e3m2fn_dtype = np.dtype(float6_e3m2fn) _float8_e3m4_dtype = np.dtype(float8_e3m4) _float8_e4m3_dtype = np.dtype(float8_e4m3) _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz) @@ -45,6 +51,33 @@ def __init__(self): self.smallest_subnormal = bfloat16(smallest_subnormal) +class _Float4E2m1fnMachArLike: + + def __init__(self): + smallest_normal = float.fromhex("0x1p0") + self.smallest_normal = float4_e2m1fn(smallest_normal) + smallest_subnormal = float.fromhex("0x0.8p0") + self.smallest_subnormal = float4_e2m1fn(smallest_subnormal) + + +class _Float6E2m3fnMachArLike: + + def __init__(self): + smallest_normal = float.fromhex("0x1p0") + self.smallest_normal = float6_e2m3fn(smallest_normal) + smallest_subnormal = float.fromhex("0x0.2p0") + self.smallest_subnormal = float6_e2m3fn(smallest_subnormal) + + +class _Float6E3m2fnMachArLike: + + def __init__(self): + smallest_normal = float.fromhex("0x1p-2") + self.smallest_normal = float6_e3m2fn(smallest_normal) + smallest_subnormal = float.fromhex("0x0.4p-2") + self.smallest_subnormal = float6_e3m2fn(smallest_subnormal) + + class _Float8E3m4MachArLike: def __init__(self): @@ -110,7 +143,7 @@ def __init__(self): class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring __doc__ = np.finfo.__doc__ - _finfo_cache: Dict[np.dtype, np.finfo] = {} + _finfo_cache: Dict[type, np.finfo] = {} @staticmethod def _bfloat16_finfo(): @@ -157,6 +190,129 @@ def float_to_str(f): # pylint: enable=protected-access return obj + @staticmethod + def _float4_e2m1fn_finfo(): + eps = float.fromhex("0x0.8p0") # 0.5 + max_ = float.fromhex("0x1.8p2") # 6.0 + + obj = object.__new__(np.finfo) + obj.dtype = _float4_e2m1fn_dtype + obj.bits = 4 + obj.eps = eps + obj.epsneg = eps + obj.machep = -1 + obj.negep = -1 + obj.max = float4_e2m1fn(max_) + obj.min = float4_e2m1fn(-max_) + obj.nexp = 2 + obj.nmant = 1 + obj.iexp = obj.nexp + obj.maxexp = 3 + obj.minexp = 0 + obj.precision = 0 + obj.resolution = float4_e2m1fn(1.0) + # pylint: disable=protected-access + obj._machar = _Float4E2m1fnMachArLike() + tiny = obj._machar.smallest_normal + if not hasattr(obj, "tiny"): + obj.tiny = tiny + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = tiny + obj.smallest_subnormal = obj._machar.smallest_subnormal + + float_to_str = str + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(obj.max) + obj._str_epsneg = float_to_str(obj.epsneg) + obj._str_eps = float_to_str(obj.eps) + obj._str_resolution = float_to_str(obj.resolution) + # pylint: enable=protected-access + return obj + + @staticmethod + def _float6_e2m3fn_finfo(): + eps = float.fromhex("0x0.2p0") # 0.125 + max_ = float.fromhex("0x1.Ep2") # 7.5 + + obj = object.__new__(np.finfo) + obj.dtype = _float6_e2m3fn_dtype + obj.bits = 6 + obj.eps = eps + obj.epsneg = eps + obj.machep = -3 + obj.negep = -3 + obj.max = float6_e2m3fn(max_) + obj.min = float6_e2m3fn(-max_) + obj.nexp = 2 + obj.nmant = 3 + obj.iexp = obj.nexp + obj.maxexp = 3 + obj.minexp = 0 + obj.precision = 0 + obj.resolution = float6_e2m3fn(1.0) + # pylint: disable=protected-access + obj._machar = _Float6E2m3fnMachArLike() + tiny = obj._machar.smallest_normal + if not hasattr(obj, "tiny"): + obj.tiny = tiny + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = tiny + obj.smallest_subnormal = obj._machar.smallest_subnormal + + float_to_str = str + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(obj.max) + obj._str_epsneg = float_to_str(obj.epsneg) + obj._str_eps = float_to_str(obj.eps) + obj._str_resolution = float_to_str(obj.resolution) + # pylint: enable=protected-access + return obj + + @staticmethod + def _float6_e3m2fn_finfo(): + eps = float.fromhex("0x1p-2") # 0.25 + max_ = float.fromhex("0x1.Cp4") # 28 + + obj = object.__new__(np.finfo) + obj.dtype = _float6_e3m2fn_dtype + obj.bits = 6 + obj.eps = eps + obj.epsneg = eps / 2 + obj.machep = -2 + obj.negep = -3 + obj.max = float6_e3m2fn(max_) + obj.min = float6_e3m2fn(-max_) + obj.nexp = 3 + obj.nmant = 2 + obj.iexp = obj.nexp + obj.maxexp = 5 + obj.minexp = -2 + obj.precision = 0 + obj.resolution = float6_e3m2fn(1.0) + # pylint: disable=protected-access + obj._machar = _Float6E3m2fnMachArLike() + tiny = obj._machar.smallest_normal + if not hasattr(obj, "tiny"): + obj.tiny = tiny + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = tiny + obj.smallest_subnormal = obj._machar.smallest_subnormal + + float_to_str = str + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(obj.max) + obj._str_epsneg = float_to_str(obj.epsneg) + obj._str_eps = float_to_str(obj.eps) + obj._str_resolution = float_to_str(obj.resolution) + # pylint: enable=protected-access + return obj + @staticmethod def _float8_e3m4_finfo(): def float_to_str(f): @@ -472,71 +628,35 @@ def float_to_str(f): # pylint: enable=protected-access return obj + _finfo_type_map = { + bfloat16: _bfloat16_finfo, + float4_e2m1fn: _float4_e2m1fn_finfo, + float6_e2m3fn: _float6_e2m3fn_finfo, + float6_e3m2fn: _float6_e3m2fn_finfo, + float8_e3m4: _float8_e3m4_finfo, + float8_e4m3: _float8_e4m3_finfo, + float8_e4m3fn: _float8_e4m3fn_finfo, + float8_e4m3fnuz: _float8_e4m3fnuz_finfo, + float8_e4m3b11fnuz: _float8_e4m3b11fnuz_finfo, + float8_e5m2: _float8_e5m2_finfo, + float8_e5m2fnuz: _float8_e5m2fnuz_finfo, + } + _finfo_name_map = {t.__name__: t for t in _finfo_type_map} + def __new__(cls, dtype): - if ( - isinstance(dtype, str) - and dtype == "bfloat16" - or dtype == _bfloat16_dtype - ): - if _bfloat16_dtype not in cls._finfo_cache: - cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo() - return cls._finfo_cache[_bfloat16_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e3m4" - or dtype == _float8_e3m4_dtype - ): - if _float8_e3m4_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo() - return cls._finfo_cache[_float8_e3m4_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3" - or dtype == _float8_e4m3_dtype - ): - if _float8_e4m3_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo() - return cls._finfo_cache[_float8_e4m3_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3b11fnuz" - or dtype == _float8_e4m3b11fnuz_dtype - ): - if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = ( - cls._float8_e4m3b11fnuz_finfo() - ) - return cls._finfo_cache[_float8_e4m3b11fnuz_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3fn" - or dtype == _float8_e4m3fn_dtype - ): - if _float8_e4m3fn_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo() - return cls._finfo_cache[_float8_e4m3fn_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e4m3fnuz" - or dtype == _float8_e4m3fnuz_dtype - ): - if _float8_e4m3fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo() - return cls._finfo_cache[_float8_e4m3fnuz_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e5m2" - or dtype == _float8_e5m2_dtype - ): - if _float8_e5m2_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo() - return cls._finfo_cache[_float8_e5m2_dtype] - if ( - isinstance(dtype, str) - and dtype == "float8_e5m2fnuz" - or dtype == _float8_e5m2fnuz_dtype - ): - if _float8_e5m2fnuz_dtype not in cls._finfo_cache: - cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo() - return cls._finfo_cache[_float8_e5m2fnuz_dtype] + key = ( + cls._finfo_name_map.get(dtype) + if isinstance(dtype, str) + else dtype.type + if isinstance(dtype, np.dtype) + else dtype + ) + finfo = cls._finfo_cache.get(key) + if finfo is not None: + return finfo + + init = cls._finfo_type_map.get(key) + if init is not None: + cls._finfo_cache[dtype] = init.__func__() + return cls._finfo_cache[dtype] return super().__new__(cls, dtype) diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 287a60bf..39f3ffdc 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -35,6 +35,7 @@ limitations under the License. #include "_src/intn_numpy.h" #include "include/float8.h" #include "include/intn.h" +#include "include/mxfloat.h" namespace ml_dtypes { @@ -176,6 +177,45 @@ struct TypeDescriptor : CustomFloatType { static constexpr char kNpyDescrByteorder = '='; }; +template <> +struct TypeDescriptor : CustomFloatType { + typedef float6_e2m3fn T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float6_e2m3fn"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float6_e2m3fn"; + static constexpr const char* kTpDoc = "float6_e2m3fn floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + static constexpr char kNpyDescrType = '8'; + static constexpr char kNpyDescrByteorder = '='; +}; + +template <> +struct TypeDescriptor : CustomFloatType { + typedef float6_e3m2fn T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float6_e3m2fn"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float6_e3m2fn"; + static constexpr const char* kTpDoc = "float6_e3m2fn floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + static constexpr char kNpyDescrType = '9'; + static constexpr char kNpyDescrByteorder = '='; +}; + +template <> +struct TypeDescriptor : CustomFloatType { + typedef float4_e2m1fn T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float4_e2m1fn"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float4_e2m1fn"; + static constexpr const char* kTpDoc = "float4_e2m1fn floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + static constexpr char kNpyDescrType = '0'; + static constexpr char kNpyDescrByteorder = '='; +}; + template <> struct TypeDescriptor : IntNTypeDescriptor { typedef int2 T; @@ -278,6 +318,34 @@ bool RegisterOneWayCustomCast() { return true; } +// Register two-way floating point casts between the first and the other types. +template +bool RegisterTwoWayFloatCasts() { return true; } + +template +bool RegisterTwoWayFloatCasts() { + return RegisterTwoWayCustomCast() && + RegisterTwoWayFloatCasts(); +} + +// Register two-way floating point casts between all pairs of types. +template +bool RegisterAllFloatCasts() { return true; } + +template +bool RegisterAllFloatCasts() { + return RegisterTwoWayFloatCasts() && + RegisterAllFloatCasts(); +} + +// Initialize type attribute in the module object. +template +bool InitModuleType(PyObject* obj, const char* name) { + return PyObject_SetAttrString( + obj, name, + reinterpret_cast(TypeDescriptor::type_ptr)) >= 0; +} + } // namespace // Initializes the module. @@ -294,78 +362,33 @@ bool Initialize() { return false; } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { - return false; - } - if (!RegisterFloatDtype(numpy.get())) { + if (!RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get()) || + !RegisterFloatDtype(numpy.get())) { return false; } - if (!RegisterIntNDtype(numpy.get())) { - return false; - } - if (!RegisterIntNDtype(numpy.get())) { - return false; - } - if (!RegisterIntNDtype(numpy.get())) { - return false; - } - if (!RegisterIntNDtype(numpy.get())) { + if (!RegisterIntNDtype(numpy.get()) || + !RegisterIntNDtype(numpy.get()) || + !RegisterIntNDtype(numpy.get()) || + !RegisterIntNDtype(numpy.get())) { return false; } // Register casts between pairs of custom float dtypes. - bool success = true; - success &= RegisterCustomFloatCast(); - success &= - RegisterTwoWayCustomCast(); - success &= RegisterCustomFloatCast(); - success &= - RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= - RegisterTwoWayCustomCast(); - success &= - RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); - success &= RegisterTwoWayCustomCast(); + bool success = + RegisterAllFloatCasts(); success &= RegisterOneWayCustomCast(); success &= RegisterOneWayCustomCast(); return success; @@ -396,68 +419,21 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { return nullptr; } - if (PyObject_SetAttrString(m.get(), "float8_e3m4", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e4m3", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "float8_e4m3b11fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e4m3fn", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e4m3fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e5m2", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "float8_e5m2fnuz", - reinterpret_cast( - TypeDescriptor::type_ptr)) < - 0) { - return nullptr; - } - if (PyObject_SetAttrString(m.get(), "bfloat16", - reinterpret_cast( - TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "int2", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "int4", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "uint2", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { - return nullptr; - } - if (PyObject_SetAttrString( - m.get(), "uint4", - reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { + if (!InitModuleType(m.get(), "float4_e2m1fn") || + !InitModuleType(m.get(), "float6_e2m3fn") || + !InitModuleType(m.get(), "float6_e3m2fn") || + !InitModuleType(m.get(), "float8_e3m4") || + !InitModuleType(m.get(), "float8_e4m3") || + !InitModuleType(m.get(), "float8_e4m3b11fnuz") || + !InitModuleType(m.get(), "float8_e4m3fn") || + !InitModuleType(m.get(), "float8_e4m3fnuz") || + !InitModuleType(m.get(), "float8_e5m2") || + !InitModuleType(m.get(), "float8_e5m2fnuz") || + !InitModuleType(m.get(), "bfloat16") || + !InitModuleType(m.get(), "int2") || + !InitModuleType(m.get(), "int4") || + !InitModuleType(m.get(), "uint2") || + !InitModuleType(m.get(), "uint4")) { return nullptr; } diff --git a/ml_dtypes/_src/ufuncs.h b/ml_dtypes/_src/ufuncs.h index ef6f07e1..19a7a00e 100644 --- a/ml_dtypes/_src/ufuncs.h +++ b/ml_dtypes/_src/ufuncs.h @@ -325,12 +325,15 @@ std::pair, BitsType> SignAndMagnitude(T x) { // For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without // flipping the sign. Therefore, we need to explicitly check the // most-significant bit. + // For types without NaNs (i.e. mxfloat), use xor to keep the sign bit, which + // may be not the most-significant bit. constexpr BitsType kSignMask = BitsType(1) << (sizeof(BitsType) * CHAR_BIT - 1); + constexpr bool has_nan = std::numeric_limits::has_quiet_NaN; const BitsType x_abs_bits = Eigen::numext::bit_cast>(Eigen::numext::abs(x)); const BitsType x_bits = Eigen::numext::bit_cast>(x); - return {x_bits & kSignMask, x_abs_bits}; + return {has_nan ? x_bits & kSignMask : x_bits ^ x_abs_bits, x_abs_bits}; } template @@ -705,6 +708,7 @@ struct Spacing { CopySign copysign; if constexpr (!std::numeric_limits::has_infinity) { if (Eigen::numext::abs(x) == std::numeric_limits::max()) { + if constexpr (!std::numeric_limits::has_quiet_NaN) return T(); return copysign(std::numeric_limits::quiet_NaN(), x); } } diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index f7bf4bde..06039615 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -59,6 +59,7 @@ class float8_base { constexpr float8_base(uint8_t rep, ConstructFromRepTag) : rep_{rep} {} public: + static constexpr int kBits = 8; constexpr float8_base() : rep_(0) {} template @@ -200,7 +201,7 @@ class float8_base { const uint8_t x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const uint8_t x_bits = Eigen::numext::bit_cast(x); - const uint8_t x_sign = x_bits ^ x_abs_bits; + const uint8_t x_sign = (x_bits ^ x_abs_bits) << (CHAR_BIT - Derived::kBits); return {x_sign, x_abs_bits}; } static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int8_t @@ -470,9 +471,11 @@ constexpr int MinExponent10FromMinExponent(int min_exponent) { // emax * log10(2)) constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, int digits) { - // We only support digits in {3,4}. This table would grow if we wanted to + // We only support digits in {2,5}. This table would grow if we wanted to // handle more values. constexpr double kLog10OfOnePredecessor[] = { + // log10(1 - 2**-2) + -0.12493873660829993, // log10(1 - 2**-3) -0.057991946977686754, // log10(1 - 2**-4) @@ -480,7 +483,7 @@ constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, // log10(1 - 2**-5) -0.013788284485633295, }; - return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 3] + + return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 2] + max_exponent * kLog10Of2)); } @@ -1186,7 +1189,6 @@ struct ConvertImpl>> { using FromTraits = Traits; using FromBits = typename FromTraits::BitsType; - static constexpr int kFromBits = FromTraits::kBits; static constexpr int kFromMantissaBits = FromTraits::kMantissaBits; static constexpr int kFromExponentBits = FromTraits::kExponentBits; static constexpr int kFromExponentBias = FromTraits::kExponentBias; @@ -1194,7 +1196,6 @@ struct ConvertImpl; using ToBits = typename ToTraits::BitsType; - static constexpr int kToBits = ToTraits::kBits; static constexpr int kToMantissaBits = ToTraits::kMantissaBits; static constexpr int kToExponentBits = ToTraits::kExponentBits; static constexpr int kToExponentBias = ToTraits::kExponentBias; @@ -1213,7 +1214,7 @@ struct ConvertImpl(from) >> (kFromBits - 1); + Eigen::numext::bit_cast(from) >> (FromTraits::kBits - 1); const FromBits from_bits = Eigen::numext::bit_cast(Eigen::numext::abs(from)); @@ -1241,8 +1242,9 @@ struct ConvertImpl +#include + +#include "include/float8.h" +#include "Eigen/Core" + +namespace ml_dtypes { +namespace mxfloat_internal { + +// Use 8-bit storage for 6-bit and 4-bit types. +template +class mxfloat6_base : public float8_internal::float8_base { + using Base = float8_internal::float8_base; + using Base::Base; + + public: + static constexpr int kBits = 6; + + explicit EIGEN_DEVICE_FUNC operator bool() const { + return (Base::rep() & 0x1F) != 0; + } + constexpr Derived operator-() const { + return Derived::FromRep(Base::rep() ^ 0x20); + } + Derived operator-(const Derived& other) const { + return Base::operator-(other); + } +}; + +template +class mxfloat4_base : public float8_internal::float8_base { + using Base = float8_internal::float8_base; + using Base::Base; + + public: + static constexpr int kBits = 4; + + explicit EIGEN_DEVICE_FUNC operator bool() const { + return (Base::rep() & 0x07) != 0; + } + constexpr Derived operator-() const { + return Derived::FromRep(Base::rep() ^ 0x08); + } + Derived operator-(const Derived& other) const { + return Base::operator-(other); + } +}; + +class float6_e2m3fn : public mxfloat6_base { + // Exponent: 2, Mantissa: 3, bias: 1. + // Extended range: no inf, no NaN. + using Base = mxfloat6_base; + using Base::Base; + + public: + template = 0> + explicit EIGEN_DEVICE_FUNC float6_e2m3fn(T f8) + : float6_e2m3fn(ConvertFrom(f8)) {} +}; + +class float6_e3m2fn : public mxfloat6_base { + // Exponent: 3, Mantissa: 2, bias: 3. + // Extended range: no inf, no NaN. + using Base = mxfloat6_base; + using Base::Base; + + public: + template = 0> + explicit EIGEN_DEVICE_FUNC float6_e3m2fn(T f8) + : float6_e3m2fn(ConvertFrom(f8)) {} +}; + +class float4_e2m1fn : public mxfloat4_base { + // Exponent: 2, Mantissa: 1, bias: 1. + // Extended range: no inf, no NaN. + using Base = mxfloat4_base; + using Base::Base; + + public: + template = 0> + explicit EIGEN_DEVICE_FUNC float4_e2m1fn(T f8) + : float4_e2m1fn(ConvertFrom(f8)) {} +}; + +// Common properties for specializing std::numeric_limits. +template +struct numeric_limits_mxfloat_tpl { + protected: + static constexpr int kExponentBias = (1 << (E - 1)) - 1; + static constexpr int kMantissaBits = M; + + public: + // NOLINTBEGIN: these names must match std::numeric_limits. + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = false; + static constexpr std::float_denorm_style has_denorm = std::denorm_present; + static constexpr bool has_denorm_loss = false; + static constexpr std::float_round_style round_style = std::round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = kMantissaBits + 1; + static constexpr int digits10 = float8_internal::Digits10FromDigits(digits); + static constexpr int max_digits10 = + float8_internal::MaxDigits10FromDigits(digits); + static constexpr int radix = std::numeric_limits::radix; + static constexpr int min_exponent = (1 - kExponentBias) + 1; + static constexpr int min_exponent10 = + float8_internal::MinExponent10FromMinExponent(min_exponent); + static constexpr int max_exponent = kExponentBias + 2; + static constexpr int max_exponent10 = + float8_internal::MaxExponent10FromMaxExponentAndDigits(max_exponent, + digits); + static constexpr bool traps = std::numeric_limits::traps; + static constexpr bool tinyness_before = + std::numeric_limits::tinyness_before; + // NOLINTEND +}; + +struct numeric_limits_float6_e2m3fn : public numeric_limits_mxfloat_tpl<2, 3> { + // 1.0 * 2^(0) = 1 + static constexpr float6_e2m3fn min() { + return float6_e2m3fn::FromRep(0b0'01'000); + } + // -1.875 * 2^(2) = -7.5 + static constexpr float6_e2m3fn lowest() { + return float6_e2m3fn::FromRep(0b1'11'111); + } + // 1.875 * 2^(2) = 7.5 + static constexpr float6_e2m3fn max() { + return float6_e2m3fn::FromRep(0b0'11'111); + } + // 0.125 * 2^(0) = 0.125 + static constexpr float6_e2m3fn epsilon() { + return float6_e2m3fn::FromRep(0b0'00'001); + } + // 0.25 * 2^(0) = 0.25 + static constexpr float6_e2m3fn round_error() { + return float6_e2m3fn::FromRep(0b0'00'010); + } + // 0.25 * 2^(0) = 0.125 + static constexpr float6_e2m3fn denorm_min() { + return float6_e2m3fn::FromRep(0b0'00'001); + } + + // Conversion from NaNs is implementation-defined (by MX specification). + static constexpr float6_e2m3fn quiet_NaN() { + return float6_e2m3fn::FromRep(0b1'00'000); + } + static constexpr float6_e2m3fn signaling_NaN() { + return float6_e2m3fn::FromRep(0b1'00'000); + } + static constexpr float6_e2m3fn infinity() { + return float6_e2m3fn::FromRep(0b0'11'111); + } +}; + +struct numeric_limits_float6_e3m2fn : public numeric_limits_mxfloat_tpl<3, 2> { + // 1.0 * 2^(-2) = 0.25 + static constexpr float6_e3m2fn min() { + return float6_e3m2fn::FromRep(0b0'001'00); + } + // -1.75 * 2^(4) = -28 + static constexpr float6_e3m2fn lowest() { + return float6_e3m2fn::FromRep(0b1'111'11); + } + // 1.75 * 2^(4) = 28 + static constexpr float6_e3m2fn max() { + return float6_e3m2fn::FromRep(0b0'111'11); + } + // 1.0 * 2^(-2) = 0.25 + static constexpr float6_e3m2fn epsilon() { + return float6_e3m2fn::FromRep(0b0'001'00); + } + // 1.0 * 2^(0) = 1 + static constexpr float6_e3m2fn round_error() { + return float6_e3m2fn::FromRep(0b0'011'00); + } + // 0.25 * 2^(-2) = 0.0625 + static constexpr float6_e3m2fn denorm_min() { + return float6_e3m2fn::FromRep(0b0'000'01); + } + + // Conversion from NaNs is implementation-defined (by MX specification). + static constexpr float6_e3m2fn quiet_NaN() { + return float6_e3m2fn::FromRep(0b1'000'00); + } + static constexpr float6_e3m2fn signaling_NaN() { + return float6_e3m2fn::FromRep(0b1'000'00); + } + static constexpr float6_e3m2fn infinity() { + return float6_e3m2fn::FromRep(0b0'111'11); + } +}; + +struct numeric_limits_float4_e2m1fn : public numeric_limits_mxfloat_tpl<2, 1> { + // 1.0 * 2^(0) = 1 + static constexpr float4_e2m1fn min() { + return float4_e2m1fn::FromRep(0b0'01'0); + } + // -1.5 * 2^(2) = -6 + static constexpr float4_e2m1fn lowest() { + return float4_e2m1fn::FromRep(0b1'11'1); + } + // 1.5 * 2^(2) = 6 + static constexpr float4_e2m1fn max() { + return float4_e2m1fn::FromRep(0b0'11'1); + } + // 0.5 * 2^(0) = 0.5 + static constexpr float4_e2m1fn epsilon() { + return float4_e2m1fn::FromRep(0b0'00'1); + } + // 1.0 * 2^(0) = 1 + static constexpr float4_e2m1fn round_error() { + return float4_e2m1fn::FromRep(0b0'01'0); + } + // 0.5 * 2^(0) = 0.5 + static constexpr float4_e2m1fn denorm_min() { + return float4_e2m1fn::FromRep(0b0'00'1); + } + + // Conversion from NaNs is implementation-defined (by MX specification). + static constexpr float4_e2m1fn quiet_NaN() { + return float4_e2m1fn::FromRep(0b1'00'0); + } + static constexpr float4_e2m1fn signaling_NaN() { + return float4_e2m1fn::FromRep(0b1'00'0); + } + static constexpr float4_e2m1fn infinity() { + return float4_e2m1fn::FromRep(0b0'11'1); + } +}; + +// Free-functions for use with ADL and in Eigen. +constexpr inline float6_e2m3fn abs(const float6_e2m3fn& a) { + return float6_e2m3fn::FromRep(a.rep() & 0b0'11'111); +} + +constexpr inline bool(isnan)(const float6_e2m3fn& a) { return false; } + +constexpr inline float6_e3m2fn abs(const float6_e3m2fn& a) { + return float6_e3m2fn::FromRep(a.rep() & 0b0'111'11); +} + +constexpr inline bool(isnan)(const float6_e3m2fn& a) { return false; } + +constexpr inline float4_e2m1fn abs(const float4_e2m1fn& a) { + return float4_e2m1fn::FromRep(a.rep() & 0b0'11'1); +} + +constexpr inline bool(isnan)(const float4_e2m1fn& a) { return false; } + +// Define traits required for floating point conversion. +template +struct TraitsBase : public float8_internal::TraitsBase { + static constexpr int kBits = E + M + 1; + static constexpr int kMantissaBits = M; + static constexpr int kExponentBits = E; + static constexpr int kExponentBias = (1 << (E - 1)) - 1; + static constexpr uint8_t kExponentMask = ((1 << E) - 1) << M; +}; + +} // namespace mxfloat_internal + +// Exported types. +using float6_e2m3fn = mxfloat_internal::float6_e2m3fn; +using float6_e3m2fn = mxfloat_internal::float6_e3m2fn; +using float4_e2m1fn = mxfloat_internal::float4_e2m1fn; + +} // namespace ml_dtypes + +// Standard library overrides. +namespace std { + +template <> +struct numeric_limits + : public ml_dtypes::mxfloat_internal::numeric_limits_float6_e2m3fn {}; + +template <> +struct numeric_limits + : public ml_dtypes::mxfloat_internal::numeric_limits_float6_e3m2fn {}; + +template <> +struct numeric_limits + : public ml_dtypes::mxfloat_internal::numeric_limits_float4_e2m1fn {}; + +} // namespace std + +// Conversion traits. +namespace ml_dtypes { +namespace float8_internal { + +template <> +struct Traits + : public mxfloat_internal::TraitsBase {}; + +template <> +struct Traits + : public mxfloat_internal::TraitsBase {}; + +template <> +struct Traits + : public mxfloat_internal::TraitsBase {}; + +} // namespace float8_internal +} // namespace ml_dtypes + +// Eigen library overrides. +namespace Eigen { +namespace numext { + +#define MXFLOAT_EIGEN_BITCAST_IMPL(Type) \ + template <> \ + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint8_t bit_cast( \ + const Type& x) { \ + return x.rep(); \ + } \ + template <> \ + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Type bit_cast( \ + const uint8_t& x) { \ + return Type::FromRep(x); \ + } + +MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float6_e2m3fn) +MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float6_e3m2fn) +MXFLOAT_EIGEN_BITCAST_IMPL(ml_dtypes::float4_e2m1fn) + +#undef MXFLOAT_EIGEN_BITCAST_IMPL + +} // namespace numext + +// Work-around for isinf/isnan/isfinite issue on aarch64. +namespace internal { + +#define MXFLOAT_EIGEN_ISFINITE_IMPL(Type) \ + template <> \ + EIGEN_DEVICE_FUNC inline bool isinf_impl(const Type&) { \ + return false; \ + } \ + template <> \ + EIGEN_DEVICE_FUNC inline bool isnan_impl(const Type&) { \ + return false; \ + } \ + template <> \ + EIGEN_DEVICE_FUNC inline bool isfinite_impl(const Type&) { \ + return true; \ + } + +MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e2m3fn) +MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float6_e3m2fn) +MXFLOAT_EIGEN_ISFINITE_IMPL(ml_dtypes::float4_e2m1fn) + +#undef MXFLOAT_EIGEN_ISFINITE_IMPL + +} // namespace internal +} // namespace Eigen + +#endif // ML_DTYPES_MXFLOAT_H_ diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 00f9b1a0..f30d7a7f 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -31,6 +31,9 @@ import numpy as np bfloat16 = ml_dtypes.bfloat16 +float4_e2m1fn = ml_dtypes.float4_e2m1fn +float6_e2m3fn = ml_dtypes.float6_e2m3fn +float6_e3m2fn = ml_dtypes.float6_e3m2fn float8_e3m4 = ml_dtypes.float8_e3m4 float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz @@ -89,7 +92,7 @@ def binary_operation_test(a, b, op, float_type): expected = op(np.float32(a), np.float32(b)) result = op(a, b) if math.isnan(expected): - if not math.isnan(result): + if dtype_has_nan(float_type) and not math.isnan(result): raise AssertionError("%s expected to be nan." % repr(result)) else: np.testing.assert_equal( @@ -99,18 +102,25 @@ def binary_operation_test(a, b, op, float_type): def dtype_has_inf(dtype): """Determines if the dtype has an `inf` representation.""" - inf = float("inf") - is_inf = False try: - x = dtype(inf) - is_inf = np.isinf(x) + return np.isinf(dtype(float("inf"))) except (OverflowError, ValueError): - pass - return is_inf + return False + + +def dtype_has_nan(dtype): + """Determines if the dtype has an `nan` representation.""" + try: + return np.isnan(dtype(float("nan"))) + except (OverflowError, ValueError): + return False FLOAT_DTYPES = [ bfloat16, + float4_e2m1fn, + float6_e2m3fn, + float6_e3m2fn, float8_e3m4, float8_e4m3, float8_e4m3b11fnuz, @@ -140,17 +150,25 @@ def dtype_has_inf(dtype): 7, float(ml_dtypes.finfo(dtype).max), -float(ml_dtypes.finfo(dtype).max), - float("nan"), - float("-nan"), + float("nan") if dtype_has_nan(dtype) else 0.0, + float("-nan") if dtype_has_nan(dtype) else 0.0, float("inf") if dtype_has_inf(dtype) else 0.0, float("-inf") if dtype_has_inf(dtype) else 0.0, ] for dtype in FLOAT_DTYPES } +# Remove values unsupported by some types. +FLOAT_VALUES[float4_e2m1fn] = [ + x for x in FLOAT_VALUES[float4_e2m1fn] if x not in {3.5, 5, 7} +] + # Values that should round trip exactly to integer and back. INT_VALUES = { bfloat16: [0, 1, 2, 10, 34, 47, 128, 255, 256, 512], + float4_e2m1fn: [0, 1, 2, 3, 4, 6], + float6_e2m3fn: [0, 1, 2, 3, 4, 5, 6, 7], + float6_e3m2fn: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28], float8_e3m4: list( itertools.chain.from_iterable( range(1 << n, 2 << n, 1 << max(0, n - 3)) for n in range(4) @@ -184,17 +202,6 @@ def dtype_has_inf(dtype): ), } -BITS_TYPE = { - bfloat16: np.uint16, - float8_e3m4: np.uint8, - float8_e4m3: np.uint8, - float8_e4m3b11fnuz: np.uint8, - float8_e4m3fn: np.uint8, - float8_e4m3fnuz: np.uint8, - float8_e5m2: np.uint8, - float8_e5m2fnuz: np.uint8, -} - # pylint: disable=g-complex-comprehension @multi_threaded( @@ -282,8 +289,9 @@ def testStr(self, float_type): def testFromStr(self, float_type): self.assertEqual(float_type(1.2), float_type("1.2")) - self.assertTrue(np.isnan(float_type("nan"))) - self.assertTrue(np.isnan(float_type("-nan"))) + if dtype_has_nan(float_type): + self.assertTrue(np.isnan(float_type("nan"))) + self.assertTrue(np.isnan(float_type("-nan"))) if dtype_has_inf(float_type): self.assertEqual(float_type(float("inf")), float_type("inf")) self.assertEqual(float_type(float("-inf")), float_type("-inf")) @@ -451,6 +459,9 @@ def testNotEqual(self, float_type): self.assertIsInstance(result, np.bool_) def testNan(self, float_type): + if not dtype_has_nan(float_type): + self.skipTest("no NaN encoding") + a = np.isnan(float_type(float("nan"))) self.assertTrue(a) numpy_assert_allclose( @@ -496,6 +507,9 @@ def testArgmax(self, float_type): def testArgmaxOnNan(self, float_type): """Ensures we return the right thing for multiple NaNs.""" + if not dtype_has_nan(float_type): + self.skipTest("no NaN encoding") + one_with_nans = np.array( [1.0, float("nan"), float("nan")], dtype=np.float32 ) @@ -657,9 +671,7 @@ def testArray(self, float_type): self.assertTrue((x == x).all()) def testComparisons(self, float_type): - x0, x1, y0 = 30, 7, 17 - if float_type == ml_dtypes.float8_e3m4: - x0, x1, y0 = 15, 3, 9 + x0, x1, y0 = 6, 1, 3 x = np.array([x0, x1, -x0], dtype=np.float32) bx = x.astype(float_type) y = np.array([y0, x1, 0], dtype=np.float32) @@ -672,8 +684,8 @@ def testComparisons(self, float_type): np.testing.assert_equal(x >= y, bx >= by) def testEqual2(self, float_type): - a = np.array([31], float_type) - b = np.array([15], float_type) + a = np.array([7], float_type) + b = np.array([3], float_type) self.assertFalse(a.__eq__(b)) def testCanCast(self, float_type): @@ -756,7 +768,7 @@ def testCasts(self, float_type): @ignore_warning(category=ComplexWarning) def testConformNumpyComplex(self, float_type): for dtype in [np.complex64, np.complex128, np.clongdouble]: - x = np.array([1.5, 2.5 + 2.0j, 3.5], dtype=dtype) + x = np.array([0.5, 1.5 + 2.0j, 4.0], dtype=dtype) y_np = x.astype(np.float32) y_tf = x.astype(float_type) numpy_assert_allclose(y_np, y_tf, atol=2e-2, float_type=float_type) @@ -771,19 +783,12 @@ def testArange(self, float_type): np.arange(100, dtype=float_type), ) np.testing.assert_equal( - np.arange(-8, 8, 1, dtype=np.float32).astype(float_type), - np.arange(-8, 8, 1, dtype=float_type), + np.arange(-6, 6, 2, dtype=np.float32).astype(float_type), + np.arange(-6, 6, 2, dtype=float_type), ) np.testing.assert_equal( - np.arange(-0.0, -2.0, -0.25, dtype=np.float32).astype(float_type), - np.arange(-0.0, -2.0, -0.25, dtype=float_type), - ) - m = 16 - if float_type == ml_dtypes.float8_e3m4: - m = 14 - np.testing.assert_equal( - np.arange(-m, m, 2.0, dtype=np.float32).astype(float_type), - np.arange(-m, m, 2.0, dtype=float_type), + np.arange(-0.0, -2.0, -0.5, dtype=np.float32).astype(float_type), + np.arange(-0.0, -2.0, -0.5, dtype=float_type), ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") @@ -847,6 +852,7 @@ def testDivmod(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) y = rng.randn(4, 1, 7).astype(float_type) + y = np.where(y == 0, float_type(1), y) o1, o2 = np.divmod(x, y) e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32)) numpy_assert_allclose( @@ -941,46 +947,54 @@ def testFrexp(self, float_type): mant1, exp1 = np.frexp(x) mant2, exp2 = np.frexp(x.astype(np.float32)) np.testing.assert_equal(exp1, exp2) - numpy_assert_allclose(mant1, mant2, rtol=1e-2, float_type=float_type) + + kwargs = {"rtol": 0.01} + if float_type == float6_e2m3fn: + kwargs = {"rtol": 0.1} + elif float_type == float4_e2m1fn: + kwargs = {"atol": 0.25} + numpy_assert_allclose(mant1, mant2, float_type=float_type, **kwargs) def testCopySign(self, float_type): - for bits in list(range(1, 128)): + bits_type = np.uint16 if float_type == bfloat16 else np.uint8 + bit_size = ml_dtypes.finfo(float_type).bits + bit_sign = 1 << (bit_size - 1) + + for bits in range(1, min(bit_sign, 256)): with self.subTest(bits): - bits_type = BITS_TYPE[float_type] val = bits_type(bits).view(float_type) val_with_sign = np.copysign(val, float_type(-1)) val_with_sign_bits = val_with_sign.view(bits_type) - num_bits = np.iinfo(bits_type).bits - np.testing.assert_equal( - bits | (1 << (num_bits - 1)), val_with_sign_bits - ) + self.assertEqual(bits | bit_sign, val_with_sign_bits) def testNextAfter(self, float_type): one = np.array(1.0, dtype=float_type) two = np.array(2.0, dtype=float_type) zero = np.array(0.0, dtype=float_type) - nan = np.array(np.nan, dtype=float_type) np.testing.assert_equal( np.nextafter(one, two) - one, ml_dtypes.finfo(float_type).eps ) np.testing.assert_equal( - np.nextafter(one, zero) - one, -ml_dtypes.finfo(float_type).eps / 2 + np.nextafter(one, zero) - one, -ml_dtypes.finfo(float_type).epsneg ) - np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True) - np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True) np.testing.assert_equal(np.nextafter(one, one), one) smallest_denormal = ml_dtypes.finfo(float_type).smallest_subnormal np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) - for a, b in itertools.permutations([0.0, nan], 2): - np.testing.assert_equal( - np.nextafter( - np.array(a, dtype=np.float32), np.array(b, dtype=np.float32) - ), - np.nextafter( - np.array(a, dtype=float_type), np.array(b, dtype=float_type) - ), - ) + + if dtype_has_nan(float_type): + nan = np.array(np.nan, dtype=float_type) + np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True) + np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True) + for a, b in itertools.permutations([0.0, nan], 2): + np.testing.assert_equal( + np.nextafter( + np.array(a, dtype=np.float32), np.array(b, dtype=np.float32) + ), + np.nextafter( + np.array(a, dtype=float_type), np.array(b, dtype=float_type) + ), + ) @ignore_warning(category=RuntimeWarning, message="invalid value encountered") def testSpacing(self, float_type): @@ -1014,13 +1028,19 @@ def testSpacing(self, float_type): nextup = np.nextafter(x_float_type, toward) if np.isnan(spacing): self.assertTrue(np.isnan(nextup - x_float_type)) - else: + elif spacing: np.testing.assert_equal(spacing, nextup - x_float_type) + else: + # If type has no NaN or infinity, spacing of the maximum value is + # expected to be zero (next value does not exist). + self.assertFalse(dtype_has_nan(float_type)) + self.assertEqual(abs(x_float_type), ml_dtypes.finfo(float_type).max) # Check that spacing for special values gives the correct answer. with self.subTest(name="NonFinite"): - nan = float_type(float("nan")) - np.testing.assert_equal(np.spacing(nan), np.spacing(np.float32(nan))) + if dtype_has_nan(float_type): + nan = float_type(float("nan")) + np.testing.assert_equal(np.spacing(nan), np.spacing(np.float32(nan))) if dtype_has_inf(float_type): inf = float_type(float("inf")) np.testing.assert_equal(np.spacing(inf), np.spacing(np.float32(inf))) diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index c7135fc7..0823b471 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -20,6 +20,9 @@ ALL_DTYPES = [ ml_dtypes.bfloat16, + ml_dtypes.float4_e2m1fn, + ml_dtypes.float6_e2m3fn, + ml_dtypes.float6_e3m2fn, ml_dtypes.float8_e3m4, ml_dtypes.float8_e4m3, ml_dtypes.float8_e4m3b11fnuz, @@ -36,7 +39,15 @@ ml_dtypes.float8_e5m2fnuz, ] +DTYPES_WITH_NO_INFINITY_AND_NO_NAN = [ + ml_dtypes.float4_e2m1fn, + ml_dtypes.float6_e2m3fn, + ml_dtypes.float6_e3m2fn, +] + UINT_TYPES = { + 4: np.uint8, + 6: np.uint8, 8: np.uint8, 16: np.uint16, } @@ -70,7 +81,9 @@ def assert_representable(val): def assert_infinite(val): val = make_val(val) - if dtype in DTYPES_WITH_NO_INFINITY: + if dtype in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: + self.assertEqual(val, info.max) + elif dtype in DTYPES_WITH_NO_INFINITY: self.assertTrue(np.isnan(val), f"expected NaN, got {val}") else: self.assertTrue(np.isposinf(val), f"expected inf, got {val}") @@ -81,16 +94,17 @@ def assert_zero(val): self.assertEqual(np.array(0, dtype).dtype, dtype) self.assertIs(info.dtype, dtype) - self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8) + if info.bits >= 8: + self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8) self.assertEqual(info.nmant + info.nexp + 1, info.bits) assert_representable(info.tiny) - assert_representable(info.max) - assert_infinite(np.spacing(info.max)) - assert_representable(info.min) - assert_infinite(-np.spacing(info.min)) + + if dtype not in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: + assert_infinite(np.spacing(info.max)) + assert_infinite(-np.spacing(info.min)) assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp) diff --git a/ml_dtypes/tests/mxfloat_test.cc b/ml_dtypes/tests/mxfloat_test.cc new file mode 100644 index 00000000..834c1055 --- /dev/null +++ b/ml_dtypes/tests/mxfloat_test.cc @@ -0,0 +1,331 @@ +/* Copyright 2024 The ml_dtypes Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "include/mxfloat.h" + +#include + +namespace ml_dtypes { +namespace { + +TEST(FloatMXe2m3Test, NumericLimits) { + using limits = std::numeric_limits; + EXPECT_EQ(static_cast(limits::min()), 1.0); + EXPECT_EQ(static_cast(limits::max()), 7.5); + EXPECT_EQ(static_cast(limits::lowest()), -7.5); + EXPECT_EQ(static_cast(limits::epsilon()), 0.125); + EXPECT_EQ(static_cast(limits::round_error()), 0.25); + EXPECT_EQ(static_cast(limits::denorm_min()), 0.125); + EXPECT_EQ(limits::digits, 4); + EXPECT_EQ(limits::digits10, 0); + EXPECT_EQ(limits::max_digits10, 3); + EXPECT_EQ(limits::min_exponent, 1); + EXPECT_EQ(limits::min_exponent10, 0); + EXPECT_EQ(limits::max_exponent, 3); + EXPECT_EQ(limits::max_exponent10, 0); + EXPECT_EQ(limits::is_iec559, false); + EXPECT_EQ(limits::has_infinity, false); + EXPECT_EQ(limits::has_quiet_NaN, false); + EXPECT_EQ(limits::has_signaling_NaN, false); +} + +TEST(FloatMXe3m2Test, NumericLimits) { + using limits = std::numeric_limits; + EXPECT_EQ(static_cast(limits::min()), 0.25); + EXPECT_EQ(static_cast(limits::max()), 28.0); + EXPECT_EQ(static_cast(limits::lowest()), -28.0); + EXPECT_EQ(static_cast(limits::epsilon()), 0.25); + EXPECT_EQ(static_cast(limits::round_error()), 1.0); + EXPECT_EQ(static_cast(limits::denorm_min()), 0.0625); + EXPECT_EQ(limits::digits, 3); + EXPECT_EQ(limits::digits10, 0); + EXPECT_EQ(limits::max_digits10, 2); + EXPECT_EQ(limits::min_exponent, -1); + EXPECT_EQ(limits::min_exponent10, 0); + EXPECT_EQ(limits::max_exponent, 5); + EXPECT_EQ(limits::max_exponent10, 1); + EXPECT_EQ(limits::is_iec559, false); + EXPECT_EQ(limits::has_infinity, false); + EXPECT_EQ(limits::has_quiet_NaN, false); + EXPECT_EQ(limits::has_signaling_NaN, false); +} + +TEST(Float4e2m1Test, NumericLimits) { + using limits = std::numeric_limits; + EXPECT_EQ(static_cast(limits::min()), 1.0); + EXPECT_EQ(static_cast(limits::max()), 6.0); + EXPECT_EQ(static_cast(limits::lowest()), -6.0); + EXPECT_EQ(static_cast(limits::epsilon()), 0.5); + EXPECT_EQ(static_cast(limits::round_error()), 1.0); + EXPECT_EQ(static_cast(limits::denorm_min()), 0.5); + EXPECT_EQ(limits::digits, 2); + EXPECT_EQ(limits::digits10, 0); + EXPECT_EQ(limits::max_digits10, 2); + EXPECT_EQ(limits::min_exponent, 1); + EXPECT_EQ(limits::min_exponent10, 0); + EXPECT_EQ(limits::max_exponent, 3); + EXPECT_EQ(limits::max_exponent10, 0); + EXPECT_EQ(limits::is_iec559, false); + EXPECT_EQ(limits::has_infinity, false); + EXPECT_EQ(limits::has_quiet_NaN, false); + EXPECT_EQ(limits::has_signaling_NaN, false); +} + +template +constexpr int NumValues() { + return 1 << T::kBits; +} + +template +class FloatMXTest : public ::testing::Test {}; + +struct FloatMXTestNameGenerator { + template + static std::string GetName(int) { + if constexpr (std::is_same_v) return "float6_e2m3fn"; + if constexpr (std::is_same_v) return "float6_e3m2fn"; + if constexpr (std::is_same_v) return "float4_e2m1fn"; + } +}; + +using FloatMXTypes = + ::testing::Types; +TYPED_TEST_SUITE(FloatMXTest, FloatMXTypes, FloatMXTestNameGenerator); + +TYPED_TEST(FloatMXTest, NoInfinity) { + using FloatMX = TypeParam; + + EXPECT_EQ(static_cast(INFINITY), + std::numeric_limits::max()); + EXPECT_EQ(static_cast(-INFINITY), + std::numeric_limits::lowest()); +} + +TYPED_TEST(FloatMXTest, Negate) { + using FloatMX = TypeParam; + + int sign_bit = 1 << (FloatMX::kBits - 1); + for (int i = 0; i < sign_bit; ++i) { + FloatMX pos = FloatMX::FromRep(i); + FloatMX neg = FloatMX::FromRep(i | sign_bit); + EXPECT_EQ((-pos).rep(), neg.rep()); + EXPECT_EQ((-neg).rep(), pos.rep()); + } +} + +TYPED_TEST(FloatMXTest, BitCasts) { + using FloatMX = TypeParam; + + FloatMX x = FloatMX::FromRep(0x11); + EXPECT_EQ(Eigen::numext::bit_cast(x), x.rep()); + EXPECT_EQ(Eigen::numext::bit_cast(x.rep()), x); +} + +TYPED_TEST(FloatMXTest, UpCasts) { + using FloatMX = TypeParam; + + for (int i = 0; i < NumValues(); ++i) { + FloatMX mx = FloatMX::FromRep(i); + + double f64 = static_cast(mx); + float f32 = static_cast(mx); + Eigen::bfloat16 bf16 = static_cast(mx); + Eigen::half f16 = static_cast(mx); + + EXPECT_EQ(f64, f32) << i; + EXPECT_EQ(f32, bf16) << i; + EXPECT_EQ(bf16, f16) << i; + } +} + +TYPED_TEST(FloatMXTest, DownCasts) { + using FloatMX = TypeParam; + + for (int i = 0; i < NumValues(); ++i) { + float x = static_cast(FloatMX::FromRep(i)); + + FloatMX f64 = static_cast(static_cast(x)); + FloatMX f32 = static_cast(static_cast(x)); + FloatMX bf16 = static_cast(static_cast(x)); + FloatMX f16 = static_cast(static_cast(x)); + + EXPECT_EQ(f64.rep(), i); + EXPECT_EQ(f32.rep(), i); + EXPECT_EQ(bf16.rep(), i); + EXPECT_EQ(f16.rep(), i); + } +} + +TYPED_TEST(FloatMXTest, ConvertFromWithSaturation) { + using FloatMX = TypeParam; + + FloatMX upper = + FloatMX::template ConvertFrom( + static_cast(std::numeric_limits::max()) * 2); + EXPECT_EQ(upper, std::numeric_limits::max()); + + FloatMX lower = + FloatMX::template ConvertFrom( + static_cast(std::numeric_limits::lowest()) * 2); + EXPECT_EQ(lower, std::numeric_limits::lowest()); +} + +TYPED_TEST(FloatMXTest, ConvertFromWithTruncation) { + using FloatMX = TypeParam; + + // Truncation and rounding of a number ever-so-slightly less than 2. + float less_than_two = Eigen::numext::bit_cast(0x3FFFFFFF); + FloatMX truncated = + FloatMX::template ConvertFrom( + less_than_two); + EXPECT_LT(static_cast(truncated), 2); + + FloatMX rounded = + FloatMX::template ConvertFrom( + less_than_two); + EXPECT_EQ(static_cast(rounded), 2); + + // Truncation and rounding of a subnormal. + int digits = std::numeric_limits::digits; + for (int i = 1; i < (1 << (digits - 1)); ++i) { + float less_than_subnorm = + std::nexttoward(static_cast(FloatMX::FromRep(i)), 0); + + FloatMX truncated_subnorm = + FloatMX::template ConvertFrom( + less_than_subnorm); + EXPECT_EQ(truncated_subnorm.rep(), i - 1); + + FloatMX rounded_subnorm = + FloatMX::template ConvertFrom( + less_than_subnorm); + EXPECT_EQ(rounded_subnorm.rep(), i); + } +} + +TYPED_TEST(FloatMXTest, ConvertFromRoundToNearest) { + using FloatMX = TypeParam; + + // Try all pairs of values and check the middle point (which should be exactly + // representable as a float), as well as adjacent values. + for (int i = 1; i < NumValues(); ++i) { + FloatMX left = FloatMX::FromRep(i - 1); + FloatMX right = FloatMX::FromRep(i); + if (!right) continue; // Skip jump to negative zero. + + float l = static_cast(left); + float r = static_cast(right); + float m = (l + r) / 2; + float m_minus_eps = std::nexttoward(m, l); + float m_plus_eps = std::nexttoward(m, r); + + EXPECT_EQ(static_cast(m).rep(), i & 1 ? left.rep() : right.rep()); + EXPECT_EQ(static_cast(m_minus_eps).rep(), left.rep()); + EXPECT_EQ(static_cast(m_plus_eps).rep(), right.rep()); + } +} + +TYPED_TEST(FloatMXTest, CompareOperator) { + using FloatMX = TypeParam; + + for (int i = 0; i < NumValues(); ++i) { + FloatMX a = FloatMX::FromRep(i); + for (int j = 0; j < NumValues(); ++j) { + FloatMX b = FloatMX::FromRep(j); + + EXPECT_EQ(a == b, float{a} == float{b}); + EXPECT_EQ(a != b, float{a} != float{b}); + EXPECT_EQ(a < b, float{a} < float{b}); + EXPECT_EQ(a <= b, float{a} <= float{b}); + EXPECT_EQ(a > b, float{a} > float{b}); + EXPECT_EQ(a >= b, float{a} >= float{b}); + } + } +} + +#define GEN_FLOAT_TYPE_PAIRS(Type) \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair, \ + std::pair, std::pair + +#define GEN_TEST_TYPE_PAIRS() \ + GEN_FLOAT_TYPE_PAIRS(float6_e2m3fn), GEN_FLOAT_TYPE_PAIRS(float6_e3m2fn), \ + GEN_FLOAT_TYPE_PAIRS(float4_e2m1fn), \ + std::pair, \ + std::pair, \ + std::pair + +template +class FloatMXCastTest : public ::testing::Test {}; + +struct FloatMXCastTestNameGenerator { + template + static std::string GetName(int) { + std::string first_name = + ::testing::internal::GetTypeName(); + std::string second_name = + ::testing::internal::GetTypeName(); + return first_name + "_" + second_name; + } +}; + +using FloatMXCastTypePairs = ::testing::Types; +TYPED_TEST_SUITE(FloatMXCastTest, FloatMXCastTypePairs, + FloatMXCastTestNameGenerator); + +TYPED_TEST(FloatMXCastTest, FromFloatMX) { + using FloatMX = typename TypeParam::first_type; + using DestType = typename TypeParam::second_type; + + for (int i = 0; i < NumValues(); ++i) { + FloatMX mx = FloatMX::FromRep(i); + DestType converted = static_cast(mx); + DestType expected = static_cast(static_cast(mx)); + EXPECT_EQ(converted, expected); + } +} + +TYPED_TEST(FloatMXCastTest, ToFloatMX) { + using FloatMX = typename TypeParam::first_type; + using SrcType = typename TypeParam::second_type; + using SrcTraits = typename float8_internal::Traits; + + // For float8, iterate over all possible values. + // For other floating point types, discard lower mantissa bits that do not + // participate in rounding calculation to keep the test size reasonable. + constexpr bool is_fp8 = sizeof(SrcType) == 1; + + int test_bits = SrcTraits::kBits, shift = 0; + if (!is_fp8) { + int e_bits = test_bits - std::numeric_limits::digits; + int m_bits = std::numeric_limits::digits + 1; + test_bits = 1 + e_bits + m_bits; + shift = sizeof(SrcType) * CHAR_BIT - test_bits; + } + + using BitsType = typename SrcTraits::BitsType; + for (int i = 0; i < (1 << test_bits); ++i) { + BitsType value = static_cast(i) << shift; + SrcType fp = Eigen::numext::bit_cast(value); + FloatMX converted = static_cast(fp); + FloatMX expected = static_cast(static_cast(fp)); + EXPECT_EQ(converted, expected); + } +} + +} // namespace +} // namespace ml_dtypes