diff --git a/xla/literal_comparison_test.cc b/xla/literal_comparison_test.cc index 04b26922d3add1..37b7c31f267104 100644 --- a/xla/literal_comparison_test.cc +++ b/xla/literal_comparison_test.cc @@ -26,173 +26,67 @@ limitations under the License. namespace xla { namespace { -TEST(LiteralComparisonTest, F8E4M3FNCompareNear_Equal) { - auto actual = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(8.0)); - TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, F8E4M3FNCompareNear_NotEqual_1ulp) { - auto actual = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(9.0)); - auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN; - error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, F8E4M3FNCompareNear_NotEqual_4ulps) { - auto actual = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e4m3fn(12.0)); - auto error_spec = ErrorSpec(0.0, 0.0); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN; - error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN; - error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, FloatUsingF8E4M3FNCompareNear_NotEqual_4ulps) { - auto actual = LiteralUtil::CreateR0(8.0); - auto expected = LiteralUtil::CreateR0(12.1); - auto error_spec = ErrorSpec(0.0, 0.0); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN; - error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3FN; - error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, F8E5M2CompareNear_Equal) { - auto actual = LiteralUtil::CreateR0(tsl::float8_e5m2(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e5m2(8.0)); - TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, F8E5M2CompareNear_NotEqual_1ulp) { - auto actual = LiteralUtil::CreateR0(tsl::float8_e5m2(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e5m2(10.0)); - auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2; - error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} - -TEST(LiteralComparisonTest, F8E5M2CompareNear_NotEqual_4ulps) { - auto actual = LiteralUtil::CreateR0(tsl::float8_e5m2(8.0)); - auto expected = - LiteralUtil::CreateR0(tsl::float8_e5m2(14.0)); - auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2; - error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} +template +class LiteralComparisonTest : public ::testing::Test {}; -TEST(LiteralComparisonTest, FloatUsingF8E5M2CompareNear_NotEqual_4ulps) { - auto actual = LiteralUtil::CreateR0(8.0); - auto expected = LiteralUtil::CreateR0(13.0); - auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E5M2; - error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, - /*detailed_message=*/false, - /*miscompare_callback=*/nullptr)); -} +using TestedTypes = ::testing::Types; +TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); -TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_Equal) { - auto actual = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(8.0)); - auto expected = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(8.0)); +TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } -TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_NotEqual_1ulp) { - auto actual = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(8.0)); - auto expected = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(9.0)); +TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { + PrimitiveType type = primitive_util::NativeToPrimitiveType(); + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + float expV = type == F8E5M2 ? 10.0 : 9.0; + auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ; + error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } -TEST(LiteralComparisonTest, F8E4M3B11FNUZCompareNear_NotEqual_4ulps) { - auto actual = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(8.0)); - auto expected = LiteralUtil::CreateR0( - tsl::float8_e4m3b11fnuz(12.0)); +TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { + PrimitiveType type = primitive_util::NativeToPrimitiveType(); + auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); + float expV = type == F8E5M2 ? 14.0 : 12.0; + auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ; + error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ; + error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } -TEST(LiteralComparisonTest, FloatUsingF8E4M3B11FNUZCompareNear_NotEqual_4ulps) { +TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { + PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(8.0); - auto expected = LiteralUtil::CreateR0(12.1); + float expV = type == F8E5M2 ? 13.0 : 12.1; + auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ; + error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); - error_spec.low_precision_fp_error_spec.type = PrimitiveType::F8E4M3B11FNUZ; + error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, /*detailed_message=*/false,