diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index 8476b38a3b859..a140d2e933af9 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -139,6 +139,13 @@ class FloatNormalizationTest : public HloTestBase { } }; +class FloatNormalizationF8Test + : public FloatNormalizationTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, + ::testing::Values(F8E5M2)); + TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -500,10 +507,11 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) { EXPECT_EQ(root->operand(0)->shape().element_type(), U16); } -TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { +TEST_P(FloatNormalizationF8Test, ResolveIfUnsupportedF8) { + PrimitiveType f8_type = GetParam(); auto builder = HloComputation::Builder(TestName()); Shape f16_shape = ShapeUtil::MakeShape(F16, {2, 4}); - Shape f8_shape = ShapeUtil::MakeShape(F8E5M2, {2, 4}); + Shape f8_shape = ShapeUtil::MakeShape(f8_type, {2, 4}); HloInstruction* a = builder.AddInstruction( HloInstruction::CreateParameter(0, f16_shape, "a")); @@ -521,7 +529,7 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get(), F8E5M2, F16)); + EXPECT_TRUE(Normalize(module.get(), f8_type, F16)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1);