-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add float8_e4m3 and float8_e3m4 types support #23585
Conversation
Thanks for the contribution! I don't think we'll be able to bump our The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of |
Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71 Basically, we only define the dtype in JAX if it's defined in Another strategy we could use is the module-level |
Incidentally, the current TF pin is : If we release I suspect we could ease this process if we committed to semver for |
That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason). |
0336705
to
5553be7
Compare
5553be7
to
090ff3e
Compare
I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0 |
a6218f2
to
0b862ed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good – at some point we should refactor things so that we keep a single source of truth for the list of custom float types, rather than re-defining it a half dozen times across the package. But that can be for another PR.
Please fix the lint issues – thanks! Also, the test failures look real. It seems that there's some place where the new float8 types must be registered |
0b862ed
to
be89ca7
Compare
MyPyfixed mypy issues btw, Contributing to JAX explains how to run lint/ruff/mypy/jupytext locally pre-commit run --all-files All passed now Regarding failed testsFAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4e4m3 and e3m4 were added to stablehlo 1.7.0 (3 weeks ago, Sep 4, 2024) jax/_src/export/_export.py uses
it returns target_version 1.5.0 workaround - use NONE instead of WEEK_4 - it returns 1.7.5 FAILED tests/array_test.py::JaxArrayTest::test_shards_have_correct_dtype17FAILED tests/dtypes_test.py::TestPromotionTables::testFloat8PromotionErrorThe tests work fine if I use XLA COMMIT_ID from XLA PR openxla/xla#16585 Add support for float8_e4m3 and float8_e3m4 types (in Review) I guess we need to put this PR on hold and rerun the tests once XLA PR-16585 is merged to XLA main and XLA_COMMIT is updated in JAX. StableHLO WEEK_4 issue should resolve itself in 1-2 weeks too. Test report on CPU pytest -n auto tests/
26921 passed, 12530 skipped in 492.32s (0:08:12) |
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037
Imported from GitHub PR #16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723 PiperOrigin-RevId: 680651037
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037
Imported from GitHub PR #16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723 PiperOrigin-RevId: 680651037
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037
Imported from GitHub PR #16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723 PiperOrigin-RevId: 680651037
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6 PiperOrigin-RevId: 680651037
We're seeing some internal failures on GPU and TPU backends. I'll try to debug. |
Will try to build/test on GPU instance |
The error is this:
Something to do with one of the lower-level passes not knowing how to consume the new serialized types |
7937051
to
f611112
Compare
I tested the PR on GPU instance (nvidia A10G 23GB) and found that The issue occurs because When I replace Temporary I limited # TODO: Remove "cpu" check once xla::GetDefaultStablehloVersion() is 1.7.0+
if device_under_test() == "cpu" and jax._src.lib.version >= (0, 4, 35):
... I also opened an XLA PR to review the possibility to use |
Yeah, we're aware of this. We need to plumb the plugin's stablehlo version to JAX, which we haven't done yet. We should always be able to produce the newest stablehlo the consumer can consume, but currently we are overly conservative, I think. @dfm |
It is likely that the same code is utilized in some closed-source PJRT plugins/compilers.
The tests work fine on GPU if
|
f611112
to
c95f301
Compare
Jake, can you help with re-running the tests. I temporary limited f8e3m4 and f8e4m3 tests to run on "cpu" only. |
c95f301
to
1136024
Compare
Summary of PR Testing on GPU:Installed packages:
Results:
*Note: Initial test failures occurred for f8e3m4 and f8e4m3 on the GPU due to the use of StableHLO version 1.1.0 (WEEK_12). Jake, Peter, Dan, |
We're still seeing some new failures in the PJRT runtime – I'm not sure how to address those. @hawkinsp do you have thoughts on how to proceed here? |
StableHLO WEEK_12 is in the process of updating to version 1.5.0 (openxla/stablehlo#2599) WEEK_12 is scheduled to switch to version 1.7.0 (where f8e4m3 was added) in about 36 days (on or after November 28, 2024). Are there any potential workarounds for the "failed internal tests for GPU and TPU backends" (reported here) that could allow this PR to be shipped earlier? |
1136024
to
78da9fa
Compare
I retested this PR using jaxlib 0.4.35 (released Oct 22) on a server with 4 GPU devices. The checks and workarounds I previously added to skip certain tests for f8e4m3/f8e3m4 types are no longer needed, so I removed them and rebased this PR. All tests passed on the 4-GPU setup. Jake, Peter, do you think we should attempt merging this PR again now that jaxlib 0.4.35 was released two weeks ago and all tests passed on GPU without any workarounds or skips? |
The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not. Change in preparation for landing #23585 without breaking existing tests. PiperOrigin-RevId: 697727110
I think #24956 will unblock merging this. The problem is that we shouldn't be running for every custom dtype in the tests, if any given backend only supports a subset of them. |
The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not. Change in preparation for landing #23585 without breaking existing tests. PiperOrigin-RevId: 697727110
The other methods of `_LazyDtypes` filter by the supported dtypes, so it's strange that this property does not. Change in preparation for landing #23585 without breaking existing tests. PiperOrigin-RevId: 697752034
Description
Amazon has proposed two new FP8 types,
Float8E4M3
andFloat8E3M4
. These types are implemented in commercially available hardware Amazon EC2 Trn1 Instances, and added to MLIR builtin types, LLVM APFloat, ml_dtypes, StableHLO.XLA has Float8E4M3 and Float8E3M4 implementation in Review. See PR links in Related PRs section below.
This PR adds f8E4M3 and f8E3M4 types support to JAX.
f8E4M3
type follows IEEE 754 convention.f8E3M4
type follows IEEE 754 conventionRelated PRs:
How to build/install
This PR requires ml_dtype version 20240821 or later.
The current version on PyPI is 0.4.0, released on April 1, 2024, which is outdated. Therefore, ml_dtypes should be installed from source.
Related issue: jax-ml/ml_dtypes#185 [Question] Can we release a new version of ml_dtypes?
Smoke test