From 521cd7bda50cd16f53a081fd36553872fdae3151 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 15 Nov 2024 00:39:59 -0800 Subject: [PATCH] PR #16775: Add test for EmitReducePrecisionIR Imported from GitHub PR https://github.com/openxla/xla/pull/16775 I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.cc` is not covered by unit tests. Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases. Changes in this PR: - Move `EmitReducePrecisionIR` out of the anonymous namespace in `xla/service/elemental_ir_emitter.cc` and declare it in `xla/service/elemental_ir_emitter.h` - Add `EmitReducePrecisionIR_F16ToF8e5m2` test - Add `EmitReducePrecisionIR_F16ToF8e4m3` test - Add `EmitReducePrecisionIR_F16ToF8e3m4` test - Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test Related PR: - [PR-16585](https://github.com/openxla/xla/pull/16585) Add support for float8_e4m3 Copybara import of the project: -- 59722056e36e5a0bab7736b4ad3897446861de0f by Alexander Pivovarov : Add test for EmitReducePrecisionIR Merging this change closes #16775 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16775 from apivovarov:elemental_ir_emitter_test 59722056e36e5a0bab7736b4ad3897446861de0f PiperOrigin-RevId: 696792994 --- xla/service/BUILD | 1 + xla/service/elemental_ir_emitter.cc | 4 +- xla/service/elemental_ir_emitter.h | 4 + xla/service/elemental_ir_emitter_test.cc | 192 +++++++++++++++++++++++ 4 files changed, 199 insertions(+), 2 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index 84ed687470769..5874dcfe17d8a 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -4287,6 +4287,7 @@ xla_test( "//xla:types", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 3372cd0a5540a..d1276e1717bab 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -86,8 +86,6 @@ using 
llvm_ir::SetToFirstInsertPoint; using xla::float8_fnuz_ir_emitter::EmitF8fnuzToFloating; using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz; -namespace { - absl::StatusOr EmitReducePrecisionIR( PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits, int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b) { @@ -231,6 +229,8 @@ absl::StatusOr EmitReducePrecisionIR( return result; } +namespace { + template llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, llvm::Value* f8_bits, diff --git a/xla/service/elemental_ir_emitter.h b/xla/service/elemental_ir_emitter.h index 9bf393670a655..b45c7e0d8265a 100644 --- a/xla/service/elemental_ir_emitter.h +++ b/xla/service/elemental_ir_emitter.h @@ -352,6 +352,10 @@ class ElementalIrEmitterForTests : public ElementalIrEmitter { HloToElementGeneratorMap generator_map_; }; +absl::StatusOr EmitReducePrecisionIR( + PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits, + int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b); + } // namespace xla #endif // XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index c0a0ddea66f92..c1edb9a4b856d 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" @@ -48,6 +49,11 @@ namespace { using std::nullopt; +struct EmitReducePrecisionIrTestCase { + float input; + std::string expected_res; +}; + class ElementalIrEmitterExecutionTest : public HloTestBase { protected: void RunTest(const std::string& hlo_text, absl::Span args) { @@ -123,6 +129,192 @@ ENTRY main { RunTest(hlo_text, {&lhs, &rhs}); } +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e5m2) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilderBase* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + EmitReducePrecisionIrTestCase test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-14, "half 0xH0400"}, + {0.250, "half 0xH3400"}, + {1.0, "half 0xH3C00"}, + {0x1.2p0, "half 0xH3C00"}, + {0x1.Cp15, "half 0xH7B00"}, + {-0x1.Cp15, "half 0xHFB00"}, + {0x1.Dp15, "half 0xH7B00"}, + {0x1.Ep15, "half 0xH7C00"}, + {0x1.0p16, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7F00"}, + {-snan, "half 0xHFF00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F8E5M2), + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F8E5M2) - 1, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = 
llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e4m3) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilderBase* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + EmitReducePrecisionIrTestCase test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-6, "half 0xH2400"}, + {0.125, "half 0xH3000"}, + {1.0, "half 0xH3C00"}, + {0x1.1p0, "half 0xH3C00"}, + {0x1.Ep7, "half 0xH5B80"}, + {-0x1.Ep7, "half 0xHDB80"}, + {0x1.E8p7, "half 0xH5B80"}, + {0x1.Fp7, "half 0xH7C00"}, + {0x1.0p8, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7E00"}, + {-snan, "half 0xHFE00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/4, + /*dest_mantissa_bits=*/3, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e3m4) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilderBase* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + EmitReducePrecisionIrTestCase test_cases[] = { + // 
clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-2, "half 0xH3400"}, + {0.5, "half 0xH3800"}, + {1.0, "half 0xH3C00"}, + {0x1.08p0, "half 0xH3C00"}, + {0x1.Fp3, "half 0xH4BC0"}, + {-0x1.Fp3, "half 0xHCBC0"}, + {0x1.F4p3, "half 0xH4BC0"}, + {0x1.F8p3, "half 0xH7C00"}, + {0x1.0p4, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7E00"}, + {-snan, "half 0xHFE00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/3, + /*dest_mantissa_bits=*/4, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, + EmitReducePrecisionIR_F16ToF8e4m3fn) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilderBase* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + + EmitReducePrecisionIrTestCase test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-6, "half 0xH2400"}, + {0.125, "half 0xH3000"}, + {1.0, "half 0xH3C00"}, + {0x1.1p0, "half 0xH3C00"}, + {0x1.Cp8, "half 0xH5F00"}, + {-0x1.Cp8, "half 0xHDF00"}, + {0x1.Dp8, "half 0xH5F00"}, + {0x1.Ep8, "half 0xH5F80"}, + {0x1.0p9, "half 0xH6000"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + // Truncate the mantissa to 3 bits. ReducePrecision cannot deal with + // f8E4M3FN's NaN representations, so don't use ReducePrecision to handle + // exponent reduction. 
+ absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/5, + /*dest_mantissa_bits=*/3, + /*quiet_nans=*/false, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + XLA_TEST_F(ElementalIrEmitterExecutionTest, ScalarDotFusion) { const char* hlo_text = R"( HloModule ScalarDotFusion