Skip to content

Commit

Permalink
Merge pull request jax-ml#181 from sergey-kozub:mxfloat
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673843535
  • Loading branch information
The ml_dtypes Authors committed Sep 12, 2024
2 parents 6f02f77 + b68531f commit 40e66e5
Show file tree
Hide file tree
Showing 11 changed files with 1,177 additions and 271 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Added new 8-bit float types following IEEE 754 convention:
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
* Added new 4-bit and 6-bit float types:
`ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.

## [0.4.0] - 2024-04-1
Expand Down
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
* `float8_e4m3fnuz`
* `float8_e5m2`
* `float8_e5m2fnuz`
- Microscaling (MX) sub-byte floating point representations including:
* `float4_e2m1fn`
* `float6_e2m3fn`
* `float6_e3m2fn`
- `int2`, `int4`, `uint2` and `uint4`: low precision integer types.

See below for specifications of these number formats.
Expand Down Expand Up @@ -66,6 +70,39 @@ A `bfloat16` number is a single-precision float truncated at 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.

### `float4_e2m1fn`

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEM`) using byte storage (higher 4
bits are unused). NaN representation is undefined.

Possible absolute values: [`0`, `0.5`, `1`, `1.5`, `2`, `3`, `4`, `6`]

### `float6_e2m3fn`

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: `0bSEEMMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-7.5`; `7.5`]

### `float6_e3m2fn`

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: `0bSEEEMM`) using byte storage (higher 2
bits are unused). NaN representation is undefined.

Possible values range: [`-28`; `28`]

### `float8_e3m4`

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.
Expand Down
9 changes: 9 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"__version__",
"bfloat16",
"finfo",
"float4_e2m1fn",
"float6_e2m3fn",
"float6_e3m2fn",
"float8_e3m4",
"float8_e4m3",
"float8_e4m3b11fnuz",
Expand All @@ -36,6 +39,9 @@
from ml_dtypes._finfo import finfo
from ml_dtypes._iinfo import iinfo
from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -50,6 +56,9 @@
import numpy as np

bfloat16: Type[np.generic]
float4_e2m1fn: Type[np.generic]
float6_e2m3fn: Type[np.generic]
float6_e3m2fn: Type[np.generic]
float8_e3m4: Type[np.generic]
float8_e4m3: Type[np.generic]
float8_e4m3b11fnuz: Type[np.generic]
Expand Down
253 changes: 186 additions & 67 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from typing import Dict

from ml_dtypes._ml_dtypes_ext import bfloat16
from ml_dtypes._ml_dtypes_ext import float4_e2m1fn
from ml_dtypes._ml_dtypes_ext import float6_e2m3fn
from ml_dtypes._ml_dtypes_ext import float6_e3m2fn
from ml_dtypes._ml_dtypes_ext import float8_e3m4
from ml_dtypes._ml_dtypes_ext import float8_e4m3
from ml_dtypes._ml_dtypes_ext import float8_e4m3b11fnuz
Expand All @@ -27,6 +30,9 @@
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
_float6_e2m3fn_dtype = np.dtype(float6_e2m3fn)
_float6_e3m2fn_dtype = np.dtype(float6_e3m2fn)
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
Expand All @@ -45,6 +51,33 @@ def __init__(self):
self.smallest_subnormal = bfloat16(smallest_subnormal)


class _Float4E2m1fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p0")
self.smallest_normal = float4_e2m1fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.8p0")
self.smallest_subnormal = float4_e2m1fn(smallest_subnormal)


class _Float6E2m3fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p0")
self.smallest_normal = float6_e2m3fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.2p0")
self.smallest_subnormal = float6_e2m3fn(smallest_subnormal)


class _Float6E3m2fnMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-2")
self.smallest_normal = float6_e3m2fn(smallest_normal)
smallest_subnormal = float.fromhex("0x0.4p-2")
self.smallest_subnormal = float6_e3m2fn(smallest_subnormal)


class _Float8E3m4MachArLike:

def __init__(self):
Expand Down Expand Up @@ -110,7 +143,7 @@ def __init__(self):

class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
_finfo_cache: Dict[type, np.finfo] = {} # pylint: disable=g-bare-generic

@staticmethod
def _bfloat16_finfo():
Expand Down Expand Up @@ -157,6 +190,129 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float4_e2m1fn_finfo():
eps = float.fromhex("0x0.8p0") # 0.5
max_ = float.fromhex("0x1.8p2") # 6.0

obj = object.__new__(np.finfo)
obj.dtype = _float4_e2m1fn_dtype
obj.bits = 4
obj.eps = eps
obj.epsneg = eps
obj.machep = -1
obj.negep = -1
obj.max = float4_e2m1fn(max_)
obj.min = float4_e2m1fn(-max_)
obj.nexp = 2
obj.nmant = 1
obj.iexp = obj.nexp
obj.maxexp = 3
obj.minexp = 0
obj.precision = 0
obj.resolution = float4_e2m1fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float4E2m1fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float6_e2m3fn_finfo():
eps = float.fromhex("0x0.2p0") # 0.125
max_ = float.fromhex("0x1.Ep2") # 7.5

obj = object.__new__(np.finfo)
obj.dtype = _float6_e2m3fn_dtype
obj.bits = 6
obj.eps = eps
obj.epsneg = eps
obj.machep = -3
obj.negep = -3
obj.max = float6_e2m3fn(max_)
obj.min = float6_e2m3fn(-max_)
obj.nexp = 2
obj.nmant = 3
obj.iexp = obj.nexp
obj.maxexp = 3
obj.minexp = 0
obj.precision = 0
obj.resolution = float6_e2m3fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float6E2m3fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float6_e3m2fn_finfo():
eps = float.fromhex("0x1p-2") # 0.25
max_ = float.fromhex("0x1.Cp4") # 28

obj = object.__new__(np.finfo)
obj.dtype = _float6_e3m2fn_dtype
obj.bits = 6
obj.eps = eps
obj.epsneg = eps / 2
obj.machep = -2
obj.negep = -3
obj.max = float6_e3m2fn(max_)
obj.min = float6_e3m2fn(-max_)
obj.nexp = 3
obj.nmant = 2
obj.iexp = obj.nexp
obj.maxexp = 5
obj.minexp = -2
obj.precision = 0
obj.resolution = float6_e3m2fn(1.0)
# pylint: disable=protected-access
obj._machar = _Float6E3m2fnMachArLike()
tiny = obj._machar.smallest_normal
if not hasattr(obj, "tiny"):
obj.tiny = tiny
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = tiny
obj.smallest_subnormal = obj._machar.smallest_subnormal

float_to_str = str
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(obj.max)
obj._str_epsneg = float_to_str(obj.epsneg)
obj._str_eps = float_to_str(obj.eps)
obj._str_resolution = float_to_str(obj.resolution)
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e3m4_finfo():
def float_to_str(f):
Expand Down Expand Up @@ -472,71 +628,34 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

_finfo_type_map = {
_bfloat16_dtype: _bfloat16_finfo,
_float4_e2m1fn_dtype: _float4_e2m1fn_finfo,
_float6_e2m3fn_dtype: _float6_e2m3fn_finfo,
_float6_e3m2fn_dtype: _float6_e3m2fn_finfo,
_float8_e3m4_dtype: _float8_e3m4_finfo,
_float8_e4m3_dtype: _float8_e4m3_finfo,
_float8_e4m3fn_dtype: _float8_e4m3fn_finfo,
_float8_e4m3fnuz_dtype: _float8_e4m3fnuz_finfo,
_float8_e4m3b11fnuz_dtype: _float8_e4m3b11fnuz_finfo,
_float8_e5m2_dtype: _float8_e5m2_finfo,
_float8_e5m2fnuz_dtype: _float8_e5m2fnuz_finfo,
}
_finfo_name_map = {t.name: t for t in _finfo_type_map}

def __new__(cls, dtype):
if (
isinstance(dtype, str)
and dtype == "bfloat16"
or dtype == _bfloat16_dtype
):
if _bfloat16_dtype not in cls._finfo_cache:
cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
return cls._finfo_cache[_bfloat16_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e3m4"
or dtype == _float8_e3m4_dtype
):
if _float8_e3m4_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e3m4_dtype] = cls._float8_e3m4_finfo()
return cls._finfo_cache[_float8_e3m4_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3"
or dtype == _float8_e4m3_dtype
):
if _float8_e4m3_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3_dtype] = cls._float8_e4m3_finfo()
return cls._finfo_cache[_float8_e4m3_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3b11fnuz"
or dtype == _float8_e4m3b11fnuz_dtype
):
if _float8_e4m3b11fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3b11fnuz_dtype] = (
cls._float8_e4m3b11fnuz_finfo()
)
return cls._finfo_cache[_float8_e4m3b11fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fn"
or dtype == _float8_e4m3fn_dtype
):
if _float8_e4m3fn_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fn_dtype] = cls._float8_e4m3fn_finfo()
return cls._finfo_cache[_float8_e4m3fn_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e4m3fnuz"
or dtype == _float8_e4m3fnuz_dtype
):
if _float8_e4m3fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e4m3fnuz_dtype] = cls._float8_e4m3fnuz_finfo()
return cls._finfo_cache[_float8_e4m3fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2"
or dtype == _float8_e5m2_dtype
):
if _float8_e5m2_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2_dtype] = cls._float8_e5m2_finfo()
return cls._finfo_cache[_float8_e5m2_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e5m2fnuz"
or dtype == _float8_e5m2fnuz_dtype
):
if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
return cls._finfo_cache[_float8_e5m2fnuz_dtype]
if isinstance(dtype, str):
key = cls._finfo_name_map.get(dtype)
elif isinstance(dtype, np.dtype):
key = dtype
else:
key = np.dtype(dtype)
i = cls._finfo_cache.get(key)
if i is not None:
return i

init = cls._finfo_type_map.get(key)
if init is not None:
cls._finfo_cache[dtype] = init()
return cls._finfo_cache[dtype]
return super().__new__(cls, dtype)
Loading

0 comments on commit 40e66e5

Please sign in to comment.