From c95f301c73d6a2a0513952c06bdc07b7e360705a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 7 Oct 2024 15:33:24 -0700 Subject: [PATCH] Add float8_e4m3 and float8_e3m4 types support --- jax/_src/dtypes.py | 19 +++++++++++++++++++ jax/_src/export/serialization.fbs | 2 ++ jax/_src/export/serialization.py | 4 ++++ jax/_src/export/serialization_generated.py | 2 ++ jax/_src/interpreters/mlir.py | 12 ++++++------ jax/_src/lax/lax.py | 12 ++++++++++-- jax/_src/numpy/lax_numpy.py | 4 ++++ jax/_src/public_test_util.py | 14 ++++++++++++++ jax/_src/test_util.py | 20 ++++++++++++++++---- jax/numpy/__init__.py | 9 +++++++++ tests/dtypes_test.py | 7 +++++++ tests/export_test.py | 22 +++++++++++++++------- 12 files changed, 108 insertions(+), 19 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1cb57..a0137f06e15f 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) + _float8_dtypes.insert(0, _float8_e4m3_dtype) +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) + _float8_dtypes.insert(0, _float8_e3m4_dtype) + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950adaa8e..59e169dc6fb6 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -64,6 +64,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..e283e0d57528 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9fdd..583b41814963 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -52,6 +52,8 @@ class DType: bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index b6e9b2f4ef07..9432874e37b1 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -186,13 +186,13 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 32e723f31172..0281cc997a65 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -942,11 +942,15 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, DotAlgorithmPreset.ANY_F8_ANY_F8_ANY | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM ): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + np.dtype(dtypes.float8_e5m2fnuz)] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3485,6 +3489,10 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + if dtypes.float8_e3m4 is not None: + fp8_dtypes += (dtypes.float8_e3m4,) + if dtypes.float8_e4m3 is not None: + fp8_dtypes += (dtypes.float8_e4m3,) return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a0e218c88cc2..2341145374f4 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -210,6 +210,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..6bbcdd08471f 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 81737f27540b..f54fbecf015b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1433,10 +1433,22 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + # TODO: Remove lib.version check once minimum_jaxlib_version is 0.4.35+ + # TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+ + if device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35): + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index bd806872990f..8fc991b2d21c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..27b1cef1230b 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,13 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +# TODO: Remove lib.version check once minimum_jaxlib_version is 0.4.35+ +# TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+ +if jtu.device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35): + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/tests/export_test.py b/tests/export_test.py index 0d946d84d22b..2f4321803675 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -905,13 +905,21 @@ def f_jax(x): # x: bool[b] for dtype in dtypes._jax_types if dtype != np.dtype("bool") ]) def test_poly_numeric_dtypes(self, dtype=np.int32): - if str(dtype) in {"float8_e4m3b11fnuz", - "float8_e4m3fnuz", - "float8_e5m2fnuz", - "int2", - "int4", - "uint2", - "uint4"}: + unsupported_dtypes = { + "float8_e4m3b11fnuz", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "int2", + "int4", + "uint2", + "uint4", + } + # TODO: Remove once minimum_jaxlib_version is 0.4.35+ + # TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+ + if not (jtu.device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35)): + unsupported_dtypes.add("float8_e3m4") + unsupported_dtypes.add("float8_e4m3") + if str(dtype) in unsupported_dtypes: self.skipTest(f"TODO: serialization not supported for {str(dtype)}") @jax.jit def f_jax(x):