diff --git a/xla/fp_util_test.cc b/xla/fp_util_test.cc index 2ba9f90c62ae5..36f0c5be9d5bd 100644 --- a/xla/fp_util_test.cc +++ b/xla/fp_util_test.cc @@ -111,56 +111,56 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); -TEST(FPDistanceTest, F8E4M3FNDistance) { - // a & b are equal - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(8.0)), - 0); +// Test F8E4M3 floating-point types (F8E4M3FN) +template +class FP8E4M3DistanceTest : public ::testing::Test {}; + +using F8E4M3Types = ::testing::Types; +TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); + +TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) { + // a & b are equal, distance should be 0 + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(8.0)), 0); // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(13)), + EXPECT_EQ(CalculateDistanceInFloats(TypeParam(8.0), TypeParam(13)), 5); // a & b have different exponents - EXPECT_EQ(CalculateDistanceInFloats( - tsl::float8_e4m3fn(8.0), tsl::float8_e4m3fn(6.0)), - 4); + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(6.0)), 4); // 1 from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::denorm_min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), TypeParam(0)), 1); // 1 from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), TypeParam(0)), 1); // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::denorm_min(), - std::numeric_limits::denorm_min()), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), 2); // 1 non denorm from 0 in the positive direction - EXPECT_EQ(CalculateDistanceInFloats( - std::numeric_limits::min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::min(), TypeParam(0)), 8); // 1 non denorm from 0 in the negative direction - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - tsl::float8_e4m3fn(0)), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), TypeParam(0)), 8); // a & b have different signs - EXPECT_EQ(CalculateDistanceInFloats( - -std::numeric_limits::min(), - std::numeric_limits::min()), + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), 16); }