Skip to content

Commit

Permalink
Add float8_e4m3 and float8_e3m4 types support
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry authored and apivovarov committed Oct 10, 2024
1 parent 8ef41a6 commit c95f301
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 19 deletions.
19 changes: 19 additions & 0 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,17 @@ def type(self) -> type: ...


# fp8 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz

_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
Expand Down Expand Up @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e5m2fnuz_dtype,
]

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3_dtype)
_float8_dtypes.insert(0, _float8_e4m3_dtype)
if hasattr(ml_dtypes, "float8_e3m4"):
float8_e3m4 = ml_dtypes.float8_e3m4
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
_float8_dtypes.insert(0, _float8_e3m4_dtype)

# 2-bit integer support
int2: type[np.generic] | None = None
uint2: type[np.generic] | None = None
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/export/serialization.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ enum DType: byte {
i4 = 15,
ui4 = 16,

f8_e3m4 = 24,
f8_e4m3 = 23,
f8_e4m3b11fnuz = 17,
f8_e4m3fn = 18,
f8_e4m3fnuz = 19,
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/export/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
}

if dtypes._float8_e3m4_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3

_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/export/serialization_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class DType:
bf16 = 14
i4 = 15
ui4 = 16
f8_e3m4 = 24
f8_e4m3 = 23
f8_e4m3b11fnuz = 17
f8_e4m3fn = 18
f8_e4m3fnuz = 19
Expand Down
12 changes: 6 additions & 6 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,13 @@ def _is_ir_values(x: IrValues) -> bool:

if dtypes.int2 is not None:
assert dtypes.uint2 is not None
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
ir.IntegerType.get_signless, 2
)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(
ir.IntegerType.get_unsigned, 2
)
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)

if dtypes.float8_e3m4 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
Expand Down
12 changes: 10 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,11 +942,15 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
):
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz))
np.dtype(dtypes.float8_e5m2fnuz)]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
Expand Down Expand Up @@ -3485,6 +3489,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
if dtypes.float8_e3m4 is not None:
fp8_dtypes += (dtypes.float8_e3m4,)
if dtypes.float8_e4m3 is not None:
fp8_dtypes += (dtypes.float8_e4m3,)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
if dtypes.float8_e3m4 is not None:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
Expand Down
14 changes: 14 additions & 0 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}

# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
if _dtypes.float8_e3m4 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))

Expand All @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
_dtypes.float8_e5m2fnuz,
_dtypes.bfloat16,
]

if _dtypes.float8_e4m3 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4)

def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)
Expand Down
20 changes: 16 additions & 4 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,10 +1433,22 @@ def supported(self, dtypes):

@_cached_property
def custom_floats(self):
return [np.dtype(t) for t in [
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]
float_dtypes = [
_dtypes.bfloat16,
_dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn,
_dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2,
_dtypes.float8_e5m2fnuz,
]
# TODO: Remove lib.version check once minimum_jaxlib_version is 0.4.35+
# TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+
if device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35):
if _dtypes.float8_e3m4 is not None:
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
return [np.dtype(t) for t in float_dtypes]

@_cached_property
def floating(self):
Expand Down
9 changes: 9 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@
except ImportError:
pass

# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
try:
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
)
except ImportError:
pass

from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,
Expand Down
7 changes: 7 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz)]
# TODO: Remove lib.version check once minimum_jaxlib_version is 0.4.35+
# TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+
if jtu.device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35):
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

Expand Down
22 changes: 15 additions & 7 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,13 +905,21 @@ def f_jax(x): # x: bool[b]
for dtype in dtypes._jax_types if dtype != np.dtype("bool")
])
def test_poly_numeric_dtypes(self, dtype=np.int32):
if str(dtype) in {"float8_e4m3b11fnuz",
"float8_e4m3fnuz",
"float8_e5m2fnuz",
"int2",
"int4",
"uint2",
"uint4"}:
unsupported_dtypes = {
"float8_e4m3b11fnuz",
"float8_e4m3fnuz",
"float8_e5m2fnuz",
"int2",
"int4",
"uint2",
"uint4",
}
# TODO: Remove once minimum_jaxlib_version is 0.4.35+
# TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+
if not (jtu.device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35)):
unsupported_dtypes.add("float8_e3m4")
unsupported_dtypes.add("float8_e4m3")
if str(dtype) in unsupported_dtypes:
self.skipTest(f"TODO: serialization not supported for {str(dtype)}")
@jax.jit
def f_jax(x):
Expand Down

0 comments on commit c95f301

Please sign in to comment.