Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add float8_e4m3 #161

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]

* Added new 8-bit float type following IEEE 754 convention:
`ml_dtypes.float8_e4m3`.

## [0.4.0] - 2024-04-1

* Updates `ml_dtypes` for compatibility with future NumPy 2.0 release.
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
an alternative to the standard [`float16`](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) format
- `float8_*`: several experimental 8-bit floating point representations
including:
* `float8_e4m3`
* `float8_e4m3b11fnuz`
* `float8_e4m3fn`
* `float8_e4m3fnuz`
Expand Down Expand Up @@ -64,6 +65,10 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float8_e4m3`

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

### `float8_e4m3b11fnuz`

Exponent: 4, Mantissa: 3, bias: 11.
Expand Down
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"__version__",
"bfloat16",
"finfo",
"float8_e4m3",
"float8_e4m3b11fnuz",
"float8_e4m3fn",
"float8_e4m3fnuz",
Expand All @@ -34,6 +35,7 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
Expand All @@ -46,6 +48,7 @@
import numpy as np

bfloat16: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
float8_e4m3fn: Type[np.generic]
float8_e4m3fnuz: Type[np.generic]
Expand Down
64 changes: 64 additions & 0 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
from ml_dtypes._ml_dtypes_ext import float8_e4m3fn
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
Expand All @@ -25,6 +26,7 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
Expand All @@ -41,6 +43,15 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float8E4m3MachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-6")
self.smallest_normal = float8_e4m3(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-9")
self.smallest_subnormal = float8_e4m3(smallest_subnormal)


class _Float8E4m3b11fnuzMachArLike:

def __init__(self):
Expand Down Expand Up @@ -135,6 +146,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e4m3_finfo():
def float_to_str(f):
return "%6.2e" % float(f)

tiny = float.fromhex("0x1p-6") # 1/64 min normal
resolution = 0.1
eps = float.fromhex("0x1p-3") # 1/8
epsneg = float.fromhex("0x1p-4") # 1/16
max_ = float.fromhex("0x1.Ep7") # 240 max normal

obj = object.__new__(np.finfo)
obj.dtype = _float8_e4m3_dtype
obj.bits = 8
obj.eps = float8_e4m3(eps)
obj.epsneg = float8_e4m3(epsneg)
obj.machep = -3
obj.negep = -4
obj.max = float8_e4m3(max_)
obj.min = float8_e4m3(-max_)
obj.nexp = 4
obj.nmant = 3
obj.iexp = obj.nexp
obj.maxexp = 8
obj.minexp = -6
obj.precision = 1
obj.resolution = float8_e4m3(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E4m3MachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e4m3(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e4m3b11fnuz_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -369,6 +425,14 @@ def __new__(cls, dtype):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3"
or dtype == _float8_e4m3_dtype
):
if _float8_e4m3_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
return cls._finfo_cache[_float8_e4m3_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3b11fnuz"
Expand Down
28 changes: 28 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ struct TypeDescriptor<bfloat16> : CustomFloatType<bfloat16> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e4m3> : CustomFloatType<float8_e4m3> {
typedef float8_e4m3 T;
static constexpr bool is_floating = true;
static constexpr bool is_integral = false;
static constexpr const char* kTypeName = "float8_e4m3";
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e4m3";
static constexpr const char* kTpDoc = "float8_e4m3 floating-point values";
// Set e4m3 kind as Void since kind=f (float) with itemsize=1 is used by e5m2
static constexpr char kNpyDescrKind = 'V'; // Void
static constexpr char kNpyDescrType = '7'; // '4' is reserved for e4m3fn
static constexpr char kNpyDescrByteorder = '='; // Native byte order
};

template <>
struct TypeDescriptor<float8_e4m3b11fnuz>
: CustomFloatType<float8_e4m3b11fnuz> {
Expand Down Expand Up @@ -269,6 +283,9 @@ bool Initialize() {
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e4m3>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e4m3b11fnuz>(numpy.get())) {
return false;
}
Expand Down Expand Up @@ -319,6 +336,12 @@ bool Initialize() {
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e5m2fnuz, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, bfloat16, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3b11fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fnuz, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e4m3, float8_e5m2, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -349,6 +372,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
return nullptr;
}

if (PyObject_SetAttrString(m.get(), "float8_e4m3",
reinterpret_cast<PyObject*>(
TypeDescriptor<float8_e4m3>::type_ptr)) < 0) {
return nullptr;
}
if (PyObject_SetAttrString(
m.get(), "float8_e4m3b11fnuz",
reinterpret_cast<PyObject*>(
Expand Down
Loading
Loading