From d5b8b575e453e3c0fe05b4eeb993e6f1253d490a Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 16 Sep 2024 03:56:46 -0700 Subject: [PATCH] PR #17182: Parametrize ConstantsFloatTest OneCellFloat Imported from GitHub PR https://github.com/openxla/xla/pull/17182 Parametrize ConstantsFloatTest OneCellFloat Copybara import of the project: -- ad01a3eadc77385968af54f27685133caca5e8f9 by Alexander Pivovarov : Parametrize ConstantsFloatTest OneCellFloat Merging this change closes #17182 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17182 from apivovarov:param_constants_test ad01a3eadc77385968af54f27685133caca5e8f9 PiperOrigin-RevId: 675086252 --- xla/tests/constants_test.cc | 75 ++++++++----------------------------- 1 file changed, 16 insertions(+), 59 deletions(-) diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 26407b4279052..1a462f7fbad51 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -43,6 +43,16 @@ class ConstantsTest : public ClientLibraryTestBase { const ErrorSpec error_spec_{1e-3, 1e-5}; }; +template +class ConstantsFloatTest : public ConstantsTest {}; + +typedef ::testing::Types + FloatTypes; + +TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); + TEST_F(ConstantsTest, ZeroCellF32) { XlaBuilder builder(TestName()); ConstantR1(&builder, {}); @@ -50,13 +60,14 @@ TEST_F(ConstantsTest, ZeroCellF32) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -TEST_F(ConstantsTest, OneCellF32) { - std::vector constant = {2.0}; +TYPED_TEST(ConstantsFloatTest, OneCellFloat) { + std::vector constant = {TypeParam{2.0}}; - XlaBuilder builder(TestName()); - ConstantR1(&builder, constant); + XlaBuilder builder(ClientLibraryTestBase::TestName()); + ConstantR1(&builder, constant); - ComputeAndCompareR1(&builder, constant, {}, error_spec_); + ClientLibraryTestBase::ComputeAndCompareR1(&builder, constant, {}, + this->error_spec_); } TEST_F(ConstantsTest, OneCellS32) { @@ -99,60 +110,6 @@ TEST_F(ConstantsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(OneCellS4))) { ComputeAndCompareR1(&builder, {-2}, {}); } -TEST_F(ConstantsTest, OneCellF16) { - std::vector constant = {half{2.0}}; - - XlaBuilder builder(TestName()); - auto c = ConstantR1(&builder, constant); - // F16 outputs are not yet supported so convert to F32 - ConvertElementType(c, F32); - - ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); -} - -TEST_F(ConstantsTest, OneCellF8e5m2) { - std::vector constant = {tsl::float8_e5m2{2.0}}; - - XlaBuilder builder(TestName()); - auto c = ConstantR1(&builder, constant); - // F8 outputs are not yet supported so convert to F32 - ConvertElementType(c, F32); - - ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); -} - -TEST_F(ConstantsTest, OneCellF8e4m3b11fnuz) { - std::vector constant = { - tsl::float8_e4m3b11fnuz{2.0}}; - - XlaBuilder builder(TestName()); - auto c = ConstantR1(&builder, constant); - // F8 outputs are not yet supported so convert to F32 - ConvertElementType(c, F32); - - ComputeAndCompareR1(&builder, {2.0f}, {}, error_spec_); -} - -TEST_F(ConstantsTest, OneCellF8e5m2fnuz) { - std::vector constant = {tsl::float8_e5m2fnuz{2.0}}; - - XlaBuilder builder(TestName()); - ConstantR1(&builder, constant); - - ComputeAndCompareR1(&builder, constant, {}, - error_spec_); -} - -TEST_F(ConstantsTest, OneCellF8e4m3fnuz) { - std::vector constant = {tsl::float8_e4m3fnuz{2.0}}; - - XlaBuilder builder(TestName()); - ConstantR1(&builder, constant); - - ComputeAndCompareR1(&builder, constant, {}, - error_spec_); -} - TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};