From 95d52518e0cc50f67152eb8862032f36619bdfa4 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 16 Sep 2024 04:13:34 -0700 Subject: [PATCH] PR #17135: Add TypeParam to FP8E4M3DistanceTest Imported from GitHub PR https://github.com/openxla/xla/pull/17135 Add TypeParam to FP8E4M3DistanceTest This test suite will be extended with tsl::float8_e4m3 type in future Copybara import of the project: -- 8afe5a96ad239f9bd546b0575d9c85a60ad8af80 by Alexander Pivovarov : Add TypeParam to FP8E4M3DistanceTest Merging this change closes #17135 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17135 from apivovarov:dedup_fp_util_test 8afe5a96ad239f9bd546b0575d9c85a60ad8af80 PiperOrigin-RevId: 675090814 --- xla/fp_util_test.cc | 56 ++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 2ba9f90c62ae5..36f0c5be9d5bd 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -111,56 +111,56 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); -TEST(FPDistanceTest, F8E4M3FNDistance) { - // a & b are equal - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(8.0)), - 0); +// Test F8E4M3 floating-point types (F8E4M3FN) +template +class FP8E4M3DistanceTest : public ::testing::Test {}; + +using F8E4M3Types = ::testing::Types; +TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); + +TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) { + // a & b are equal, distance should be 0 + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(8.0)), 0); // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(13)), + EXPECT_EQ(CalculateDistanceInFloats(TypeParam(8.0), TypeParam(13)), 5); // a & b have different exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(6.0)), - 4); + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(6.0)), 4); // 1 from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::denorm_min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), TypeParam(0)), 1); // 1 from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), TypeParam(0)), 1); // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - std::numeric_limits::denorm_min()), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), 2); // 1 non denorm from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::min(), TypeParam(0)), 8); // 1 non denorm from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), TypeParam(0)), 8); // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - std::numeric_limits::min()), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), 16); }