Skip to content

Commit

Permalink
PR #17182: Parametrize ConstantsFloatTest OneCellFloat
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17182

Parametrize ConstantsFloatTest OneCellFloat

Copybara import of the project:

--
ad01a3e by Alexander Pivovarov <[email protected]>:

Parametrize ConstantsFloatTest OneCellFloat

Merging this change closes #17182

COPYBARA_INTEGRATE_REVIEW=#17182 from apivovarov:param_constants_test ad01a3e
PiperOrigin-RevId: 675086252
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Sep 16, 2024
1 parent f96a975 commit d5b8b57
Showing 1 changed file with 16 additions and 59 deletions.
75 changes: 16 additions & 59 deletions xla/tests/constants_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,31 @@ class ConstantsTest : public ClientLibraryTestBase {
const ErrorSpec error_spec_{1e-3, 1e-5};
};

template <typename T>
class ConstantsFloatTest : public ConstantsTest {};

typedef ::testing::Types<float, half, tsl::float8_e4m3fn,
tsl::float8_e4m3b11fnuz, tsl::float8_e4m3fnuz,
tsl::float8_e5m2, tsl::float8_e5m2fnuz>
FloatTypes;

TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes);

TEST_F(ConstantsTest, ZeroCellF32) {
XlaBuilder builder(TestName());
ConstantR1<float>(&builder, {});

ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}

TEST_F(ConstantsTest, OneCellF32) {
std::vector<float> constant = {2.0};
TYPED_TEST(ConstantsFloatTest, OneCellFloat) {
std::vector<TypeParam> constant = {TypeParam{2.0}};

XlaBuilder builder(TestName());
ConstantR1<float>(&builder, constant);
XlaBuilder builder(ClientLibraryTestBase::TestName());
ConstantR1<TypeParam>(&builder, constant);

ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
ClientLibraryTestBase::ComputeAndCompareR1<TypeParam>(&builder, constant, {},
this->error_spec_);
}

TEST_F(ConstantsTest, OneCellS32) {
Expand Down Expand Up @@ -99,60 +110,6 @@ TEST_F(ConstantsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(OneCellS4))) {
ComputeAndCompareR1<int8_t>(&builder, {-2}, {});
}

TEST_F(ConstantsTest, OneCellF16) {
std::vector<half> constant = {half{2.0}};

XlaBuilder builder(TestName());
auto c = ConstantR1<half>(&builder, constant);
// F16 outputs are not yet supported so convert to F32
ConvertElementType(c, F32);

ComputeAndCompareR1<float>(&builder, {2.0f}, {}, error_spec_);
}

TEST_F(ConstantsTest, OneCellF8e5m2) {
std::vector<tsl::float8_e5m2> constant = {tsl::float8_e5m2{2.0}};

XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e5m2>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
ConvertElementType(c, F32);

ComputeAndCompareR1<float>(&builder, {2.0f}, {}, error_spec_);
}

TEST_F(ConstantsTest, OneCellF8e4m3b11fnuz) {
std::vector<tsl::float8_e4m3b11fnuz> constant = {
tsl::float8_e4m3b11fnuz{2.0}};

XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e4m3b11fnuz>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
ConvertElementType(c, F32);

ComputeAndCompareR1<float>(&builder, {2.0f}, {}, error_spec_);
}

TEST_F(ConstantsTest, OneCellF8e5m2fnuz) {
std::vector<tsl::float8_e5m2fnuz> constant = {tsl::float8_e5m2fnuz{2.0}};

XlaBuilder builder(TestName());
ConstantR1<tsl::float8_e5m2fnuz>(&builder, constant);

ComputeAndCompareR1<tsl::float8_e5m2fnuz>(&builder, constant, {},
error_spec_);
}

TEST_F(ConstantsTest, OneCellF8e4m3fnuz) {
std::vector<tsl::float8_e4m3fnuz> constant = {tsl::float8_e4m3fnuz{2.0}};

XlaBuilder builder(TestName());
ConstantR1<tsl::float8_e4m3fnuz>(&builder, constant);

ComputeAndCompareR1<tsl::float8_e4m3fnuz>(&builder, constant, {},
error_spec_);
}

TEST_F(ConstantsTest, EightCells) {
std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};

Expand Down

0 comments on commit d5b8b57

Please sign in to comment.