Skip to content

Commit

Permalink
PR #17177: Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17177

Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8

other f8 types such as F8E4M3 will reuse this test
Copybara import of the project:

--
d4d6825 by Alexander Pivovarov <[email protected]>:

Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8

Merging this change closes #17177

COPYBARA_INTEGRATE_REVIEW=#17177 from apivovarov:param_float_normalization_test d4d6825
PiperOrigin-RevId: 675091260
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Sep 16, 2024
1 parent 95d5251 commit e438de6
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions xla/service/float_normalization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ class FloatNormalizationTest : public HloTestBase {
}
};

class FloatNormalizationF8Test
: public FloatNormalizationTest,
public ::testing::WithParamInterface<PrimitiveType> {};

INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
::testing::Values(F8E5M2));

TEST_F(FloatNormalizationTest, NoopIfSupported) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Expand Down Expand Up @@ -500,10 +507,11 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) {
EXPECT_EQ(root->operand(0)->shape().element_type(), U16);
}

TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) {
TEST_P(FloatNormalizationF8Test, ResolveIfUnsupportedF8) {
PrimitiveType f8_type = GetParam();
auto builder = HloComputation::Builder(TestName());
Shape f16_shape = ShapeUtil::MakeShape(F16, {2, 4});
Shape f8_shape = ShapeUtil::MakeShape(F8E5M2, {2, 4});
Shape f8_shape = ShapeUtil::MakeShape(f8_type, {2, 4});

HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f16_shape, "a"));
Expand All @@ -521,7 +529,7 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) {
auto module = CreateNewVerifiedModule();
auto computation = module->AddEntryComputation(builder.Build());

EXPECT_TRUE(Normalize(module.get(), F8E5M2, F16));
EXPECT_TRUE(Normalize(module.get(), f8_type, F16));

EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
Expand Down

0 comments on commit e438de6

Please sign in to comment.