Skip to content

Commit

Permalink
PR #17133: Dedup LiteralComparisonTests
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17133

Deduplicate LiteralComparisonTests using TYPED_TEST
Copybara import of the project:

--
6ee6c75 by Alexander Pivovarov <[email protected]>:

Dedup LiteralComparisonTests

Merging this change closes #17133

COPYBARA_INTEGRATE_REVIEW=#17133 from apivovarov:dedup_literal_comparison_test 6ee6c75
PiperOrigin-RevId: 675895660
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Sep 18, 2024
1 parent 6d8bdd0 commit 9211481
Showing 1 changed file with 27 additions and 133 deletions.
160 changes: 27 additions & 133 deletions xla/literal_comparison_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,173 +26,67 @@ limitations under the License.
namespace xla {
namespace {

TEST(LiteralComparisonTest, F8E4M3FNCompareNear_Equal) {
auto actual =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(8.0));
TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0),
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E4M3FNCompareNear_NotEqual_1ulp) {
auto actual =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(9.0));
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E4M3FNCompareNear_NotEqual_4ulps) {
auto actual =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(12.0));
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, FloatUsingF8E4M3FNCompareNear_NotEqual_4ulps) {
auto actual = LiteralUtil::CreateR0<float>(8.0);
auto expected = LiteralUtil::CreateR0<float>(12.1);
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E5M2CompareNear_Equal) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(8.0));
TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0),
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E5M2CompareNear_NotEqual_1ulp) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(10.0));
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E5M2CompareNear_NotEqual_4ulps) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(8.0));
auto expected =
LiteralUtil::CreateR0<tsl::float8_e5m2>(tsl::float8_e5m2(14.0));
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}
template <typename T>
class LiteralComparisonTest : public ::testing::Test {};

TEST(LiteralComparisonTest, FloatUsingF8E5M2CompareNear_NotEqual_4ulps) {
auto actual = LiteralUtil::CreateR0<float>(8.0);
auto expected = LiteralUtil::CreateR0<float>(13.0);
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}
using TestedTypes = ::testing::Types<tsl::float8_e4m3fn,
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2>;
TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes);

TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_Equal) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(8.0));
auto expected = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(8.0));
TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) {
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0),
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_NotEqual_1ulp) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(8.0));
auto expected = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(9.0));
TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
float expV = type == F8E5M2 ? 10.0 : 9.0;
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
auto error_spec = ErrorSpec(0.0, 0.0);
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ;
error_spec.low_precision_fp_error_spec.type = type;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_NotEqual_4ulps) {
auto actual = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(8.0));
auto expected = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(12.0));
TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(8.0));
float expV = type == F8E5M2 ? 14.0 : 12.0;
auto expected = LiteralUtil::CreateR0<TypeParam>(TypeParam{expV});
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ;
error_spec.low_precision_fp_error_spec.type = type;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ;
error_spec.low_precision_fp_error_spec.type = type;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
}

TEST(LiteralComparisonTest, FloatUsingF8E4M3B11FNUZCompareNear_NotEqual_4ulps) {
TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) {
PrimitiveType type = primitive_util::NativeToPrimitiveType<TypeParam>();
auto actual = LiteralUtil::CreateR0<float>(8.0);
auto expected = LiteralUtil::CreateR0<float>(12.1);
float expV = type == F8E5M2 ? 13.0 : 12.1;
auto expected = LiteralUtil::CreateR0<float>(expV);
auto error_spec = ErrorSpec(0.0, 0.0);
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ;
error_spec.low_precision_fp_error_spec.type = type;
error_spec.low_precision_fp_error_spec.within_n_values = 1;
EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
/*miscompare_callback=*/nullptr));
error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ;
error_spec.low_precision_fp_error_spec.type = type;
error_spec.low_precision_fp_error_spec.within_n_values = 4;
EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec,
/*detailed_message=*/false,
Expand Down

0 comments on commit 9211481

Please sign in to comment.