-
Notifications
You must be signed in to change notification settings - Fork 434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for float8_e4m3 and float8_e3m4 types #16585
Conversation
0a7a780
to
5b32d43
Compare
5b32d43
to
98e4256
Compare
third_party/stablehlo/workspace.bzl
Outdated
# LINT.ThenChange(Google-internal path) | ||
|
||
tf_http_archive( | ||
name = "stablehlo", | ||
sha256 = STABLEHLO_SHA256, | ||
strip_prefix = "stablehlo-{commit}".format(commit = STABLEHLO_COMMIT), | ||
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), | ||
urls = tf_mirror_urls("https://github.com/apivovarov/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably not intended to fetch it from your repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's intended that you also update the deps in the same PR, could you split it in separate PRs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is pending the merge of StableHLO openxla/stablehlo#2482 Add f8E4M3 and f8E3M4 types support (in Review).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Alexander, Integrate StableHLO at openxla/stablehlo@4f31b2e7 was werged to XLA main today. It includes float8_e4m3
type support. My temporary change in third_party/stablehlo/workspace.bzl
was removed from this PR. @mooskagh
### Summary This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point types to StableHLO. Feedback welcome, see [RFC: Float8E4M3 and Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md) for more details. ### References and Links - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md) - [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md) - StableHLO [PR-2482](#2482) Add f8E4M3 and f8E3M4 types support - [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-16585](openxla/xla#16585) Add support for float8_e4m3
This PR adds f8E4M3 and f8E3M4 types support. f8E4M3 and f8E3M4 types follow IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](#2486) [RFC] Add f8E4M3 and f8E3M4 types support - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-16585](openxla/xla#16585) Add support for float8_e4m3
98e4256
to
4804fca
Compare
I think Reed is the best person to review, I think this will require a patch on our end due to |
Hi Reed, This PR introduces support for the new f8E4M3 type, which adheres to the IEEE-754 convention. I've already added this type to the LLVM, MLIR, ml_dtypes, and StableHLO projects. This PR extends the support to XLA and includes a reference implementation for the CPU compiler. Could you please help review this PR? |
4804fca
to
9dce4c9
Compare
9dce4c9
to
88e586c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change! Sorry for the delay in reviewing.
Normally, it's better to minimize the size of PRs, but I would prefer if E3M4 is also added in the same PR, since it touches most of the same files in the exact same way as E4M3, so it makes it easy to batch review both dtypes at once.
But if adding E3M4 to the same PR is inconvenient with you, I'm fine with this being done as a separate, future PR.
xla/client/lib/math.cc
Outdated
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN, | ||
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the type list here is growing, I would avoid hardcoding the list of FP8 types. One way to avoid this is to have DoWithUpcastToF32 take a should_upcast
bool instead of the existing upcast_types
list. Then you can pass something like should_upcast = BitWidth(b.GetShape(x).element_type) <= 16
.
There are a lot of places where we list out all FP8 types, but every place we can remove listing these out will help when more types are added :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #17130 - Add default upcasting behavior to DoWithUpcastToF32
xla/fp_util_test.cc
Outdated
@@ -111,6 +111,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, | |||
0x1.fffffffffffffp-127, | |||
0x1.aaaaaaaaaaaaap-127)); | |||
|
|||
TEST(FPDistanceTest, F8E4M3Distance) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this test is almost identical to F8E4M3FNDistance, can you merge them to avoid duplication?
One way to to create a type-parameterized test with TYPED_TEST_P. Another way would be to have a for-loop over primtiives types F8E4M3 and F8E4M3FN, and in the body use primitive_util::PrimitiveTypeSwitch
with a lambda that does the CalculateDistanceInFloats
calls. See here for an example of how to use PrimitiveTypeSwitch
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #17135 - Add TypeParam to FP8E4M3DistanceTest
} else if constexpr (std::is_integral_v<ElementwiseT>) { | ||
if constexpr (std::is_signed_v<ElementwiseT>) { | ||
if (rhs_el < static_cast<ElementwiseT>(0)) { | ||
ElementWiseBinaryOp( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't change the formatting here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
restored
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
github workflow runs pipx
to check clang formatting. Opened PR #17234 Format hlo_evaluator_typed_visitor.h
xla/literal_comparison_test.cc
Outdated
@@ -25,6 +25,63 @@ limitations under the License. | |||
namespace xla { | |||
namespace { | |||
|
|||
TEST(LiteralComparisonTest, F8E4M3CompareNear_Equal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think all the FP8 tests are duplicated for each FP8 type. Can you use TYPED_TEST_P
to deduplicate them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened #17133 - Dedup LiteralComparisonTests
xla/literal_test.cc
Outdated
@@ -644,15 +648,19 @@ TEST_F(LiteralUtilTest, IsAll) { | |||
// 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false | |||
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e5m2>({q16}).IsAll(9)); | |||
|
|||
tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3 | |||
tsl::float8_e4m3 e4m3(9); // Exactly representable in e4m3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why there is a convention of using subsequent single letters to name the FP8 values (q16, r16, s16, etc) but you should follow it or change the convention. Either name this q16, renaming the above e5m2 to p16, or rename all the other FP8 variable names to something more descriptive, as you did for this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
renamed to q16
case xla::F8E4M3: | ||
return absl::UnimplementedError("F8E4M3 not implemented"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid having to modify this every time a new FP8 type is added, remove all these FP8 cases and check if IsF8Type(literal.shape().element_type()
before the switch statement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #17170 Code dedup in execution_trace_utils LiteralToValue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just FYI, this was an intentional choice. Missing switch cases are a compiler error, so having a switch without a default case is preferable when possible. No big deal though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using single-case switch statements can make it easier for the compiler to detect potential errors in the code.
Opened PR #17279 - Use switch case without default in LiteralToValue
@@ -500,6 +500,36 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) { | |||
EXPECT_EQ(root->operand(0)->shape().element_type(), U16); | |||
} | |||
|
|||
TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e4m3) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merge this with the existing test ResolveIfUnsupportedF8e5m2, either by looping over values (F8E4M3, F8E5M2) or by using a value-parameterized test with TEST_P.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #17177 - Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8
xla/tests/constants_test.cc
Outdated
|
||
XlaBuilder builder(TestName()); | ||
auto c = ConstantR1<tsl::float8_e4m3>(&builder, constant); | ||
// F8 outputs are not yet supported so convert to F32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is actually no longer true. We should change the other tests with this comment as well. The test OneCellF8e5m2fnuz does have an FP8 output, so you can use that as an example in modifying this test.
If you want, you can also change the two existing tests with the comment to have FP8 outputs as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #17182 - Parametrize ConstantsFloatTest OneCellFloat
xla/service/elemental_ir_emitter.cc
Outdated
f16_reduced = | ||
b->CreateOr(b->CreateAnd(f16_reduced, i16_const(0x9FFF)), | ||
b->CreateLShr(b->CreateAnd(f16_reduced, i16_const(0x4000)), | ||
i16_const(1))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is effectively subtracting 8 from the exponent I think, as the difference in exponent bias is 8. Why not do that subtraction directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In contrast to the EmitF16ToF8e4m3fn
function, the EmitF16ToF8e4m3
function does not include special code to handle Inf and NaN cases. (code around constexpr int max_finite_value = 0x5F7F;
)
If I use -8 approach then several tests in //xla/tests:convert_test_cpu
FAILED.
e.g.
- inf -> -1.0
- nan -> -1.5
Example:
input is inf
EmitReducePrecisionIR returns
x = 0.11111.0000000000 (0x7C00)
Option1: minus 8
x -= 0.01000.0000000000
// x is 0.10111.0000000000
// Shift to convert to F8: x is 1.0111.000
// f8e4m3 Result is -1.0 (Wrong)
Option2: Right shift E5 exponent's leftmost bit
x = (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
// x is 0.01111.0000000000
// Shift to convert to F8: x is 0.1111.000
// f8e4m3 Result is inf (Correct)
xla/service/elemental_ir_emitter.cc
Outdated
|
||
// Set output exponent to 11111 if input exponent is 1111 (Inf or NaN) | ||
// 0.1111.000 is 0x78 | ||
// 0.11111.000000000000 is 0x7C00 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0.11111.000000000000 has 12 zeros at the end, when it should have 10.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
@@ -220,6 +220,59 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR( | |||
return result; | |||
} | |||
|
|||
llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I think you might have to modify expand_float_ops.cc, which is used by the new MLIR emitters which replace the existing emitters on GPUs. But I'm not very familiar with these new emitters. @jreiffers, can you advice on what needs to be done here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have generic support for all possible float conversions, but the emitted code might not be optimal, so it should be considered a fallback. I didn't look at these conversion routines here in detail, but if they're better, it would make sense to port them to the MLIR pipeline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated xla/service/gpu/fusions/transforms/expand_float_ops.cc
and added f8E4M3
cases to:
- IsInf()
- IsNaN()
- RewriteF8Cst::matchAndRewrite() // If we're comparing to +-0, compare the absolute values.
expand_float_ops.cc
includes a specialized function for the f8e5m2 type - EmitF16ToF8e5m2()
. This is because F16 is technically f16e5m10
. The two types are similar, with the primary difference being that the mantissa in f8e5m2
is truncated to 2 bits.
f8E4M3
has a different number of exponent and mantissa bits. The conversion can be efficiently managed using the "generic support for all possible float conversions".
Tested xla on CUDA:
//xla/tests/... 799 tests: 799 tests pass
//xla/service/... 865 tests: 865 tests pass
//xla/client/... 77 tests: 77 tests pass
//xla/runtime/... 1 test: 1 test passes
//xla/ffi/... 6 tests: 6 tests pass
//xla/hlo/... 12 tests: 12 tests pass
//xla/mlir/... 141 tests: 141 tests pass
//xla/mlir_hlo/... 98 tests: 98 tests pass
//xla/pjrt/... 26 tests: 26 tests pass
//xla/tools/... 41 tests: 41 tests pass
//xla/translate/... 61 tests: 61 tests pass
Apologies for the delay, I'm OOO this week. Will take a look on Monday. |
ml_dtypes Updates: Add float8_e4m3 and float8_e3m4 types support Fix float divmod with zero denominator Add int2 and uint2 types ml_dtypes/commits Related PRs ml_dtypes PR Add float8_e4m3 jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged) XLA PR Add support for float8_e4m3 #16585 (In Review) This closes #17075 PiperOrigin-RevId: 674396944
ml_dtypes Updates: Add float8_e4m3 and float8_e3m4 types support Fix float divmod with zero denominator Add int2 and uint2 types ml_dtypes/commits Related PRs ml_dtypes PR Add float8_e4m3 jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged) XLA PR Add support for float8_e4m3 #16585 (In Review) This closes #17075 PiperOrigin-RevId: 674396944
Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` ### `f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` ### Testing: ``` bazel test \ //xla:array2d_test \ //xla:fp_util_test \ //xla:literal_comparison_test \ //xla:literal_test \ //xla/mlir/utils:type_util_test \ //xla:primitive_util_test \ //xla/python/ifrt:dtype_test \ //xla/python:xla_client_test \ //xla/service:elemental_ir_emitter_test \ //xla/service:float_normalization_test \ //xla/service/gpu/tests:float_conversions_test \ //xla/tests:array_elementwise_ops_test \ //xla/tests:constants_test \ //xla/tests:convert_test \ //xla/tests:float8_test \ //xla:util_test bazel test \ //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \ //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \ //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test ``` ### Related PRs: - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged) - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged) - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged) - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged) - XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved) - XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template) - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review) Copybara import of the project: -- ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>: Add support for float8_e4m3 and float8_e3m4 types Merging this change closes #16585 PiperOrigin-RevId: 681551979
PR has been merged! Reed, thank you for your help, guidance, and support! @reedwm |
No problem, and thanks for the well-tested PR! Also thank you for all the test clean up PRs! Note in merging, when converting to E3M4, I had to change the code to first convert to half to take into account we do not use an ml-dtypes version that includes jax-ml/ml_dtypes#205 yet. I added TODOs in the form of |
PR #16585: Add support for float8_e4m3 and float8_e3m4 types Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 =... *** PiperOrigin-RevId: 681876540
PR #16585: Add support for float8_e4m3 and float8_e3m4 types Imported from GitHub PR openxla/xla#16585 This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler). ### `f8E4M3` type follows IEEE 754 convention. ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 =... *** PiperOrigin-RevId: 681876540
The commit hash is now 6f02f77c4fa624d8b467c36d1d959a9b49b07900, matching what is in TF at https://github.com/tensorflow/tensorflow/blob/master/third_party/py/ml_dtypes/workspace.bzl This fixes a breakage caused by openxla/xla#16585 PiperOrigin-RevId: 681951673
The commit hash is now 6f02f77c4fa624d8b467c36d1d959a9b49b07900, matching what is in TF at https://github.com/tensorflow/tensorflow/blob/master/third_party/py/ml_dtypes/workspace.bzl This fixes a breakage caused by openxla/xla#16585 PiperOrigin-RevId: 681951673
The commit hash is now 6f02f77c4fa624d8b467c36d1d959a9b49b07900, matching what is in TF at https://github.com/tensorflow/tensorflow/blob/master/third_party/py/ml_dtypes/workspace.bzl This fixes a breakage caused by openxla/xla#16585 PiperOrigin-RevId: 682038821
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696646489
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696646489
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696646489
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696646489
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696646489
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696646489
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696646489
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696730664
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696730664
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696730664
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696730664
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696730664
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696730664
Imported from GitHub PR #16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](#16585) Add support for float8_e4m3 Copybara import of the project: -- 5972205 by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205 PiperOrigin-RevId: 696792994
Imported from GitHub PR openxla/xla#16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](openxla/xla#16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov <[email protected]>: Add test for EmitReducePrecisionIR Merging this change closes #16775 PiperOrigin-RevId: 696792994
This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).
f8E4M3
type follows IEEE 754 convention.f8E3M4
type follows IEEE 754 conventionTesting:
Related PRs: