diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 7493c4e558..b6885808ce 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -126,9 +126,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // unused { - ASSERT( - !(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -137,9 +135,27 @@ void distance_impl(raft::resources const& handle, AccT* y_norm = workspace; AccT* sq_x_norm = workspace; AccT* sq_y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. 
+ if (x == y && is_row_major) { + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + sq_x_norm += std::max(m, n); + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm( + sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream); + } else { y_norm += m; - raft::linalg::reduce(x_norm, x, k, @@ -167,21 +183,6 @@ void distance_impl(raft::resources const& handle, sq_y_norm = sq_x_norm + m; raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); - } else { - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - sq_x_norm += m; - sq_y_norm = sq_x_norm; - raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; @@ -210,23 +211,25 @@ void distance_impl(raft::resources const& handle, "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); cudaStream_t stream = raft::resource::get_cuda_stream(handle); DataT* x_norm = workspace; DataT* y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. 
+ if (x == y && is_row_major) { + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + } else { y_norm += m; raft::linalg::rowNorm( x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } else { - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; @@ -453,21 +456,29 @@ void distance_impl_l2_expanded( // NOTE: different name "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); DataT* x_norm = workspace; DataT* y_norm = workspace; - if (x != y) { + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if ((x == y) && is_row_major) { + raft::linalg::rowNorm(x_norm, + x, + k, + std::max(m, n), + raft::linalg::L2Norm, + is_row_major, + stream, + raft::identity_op{}); + } else { y_norm += m; raft::linalg::rowNorm( x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - } else { - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; @@ -789,8 +800,10 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 
2 : 1; if (is_allocated) { + // TODO: when X == Y, allocate std::max(m, n) instead of m + n once the column-major input + // accuracy issue is resolved; until then we allocate m + n. worksize += numOfBuffers * m * sizeof(AccType); - if (x != y) worksize += numOfBuffers * n * sizeof(AccType); + worksize += numOfBuffers * n * sizeof(AccType); } return worksize; diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu index fc729dec1c..aa2866483a 100644 --- a/cpp/test/distance/dist_correlation.cu +++ b/cpp/test/distance/dist_correlation.cu @@ -24,6 +24,10 @@ template class DistanceCorrelation : public DistanceTest {}; +template +class DistanceCorrelationXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, @@ -44,6 +48,25 @@ TEST_P(DistanceCorrelationF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); +typedef DistanceCorrelationXequalY DistanceCorrelationXequalYF; +TEST_P(DistanceCorrelationXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + raft::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu index 9e1cf5af17..caf55529ed 100644 --- a/cpp/test/distance/dist_cos.cu +++ b/cpp/test/distance/dist_cos.cu @@ -24,6 +24,10 @@ template class DistanceExpCos : public DistanceTest { }; +template +class DistanceExpCosXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 
1024, 1024, 32, true, 1234ULL}, {0.001f, 1024, 32, 1024, true, 1234ULL}, @@ -34,6 +38,18 @@ const std::vector> inputsf = { {0.001f, 32, 1024, 1024, false, 1234ULL}, {0.003f, 1024, 1024, 1024, false, 1234ULL}, }; + +const std::vector> inputsXeqYf = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + typedef DistanceExpCos DistanceExpCosF; TEST_P(DistanceExpCosF, Result) { @@ -44,6 +60,29 @@ TEST_P(DistanceExpCosF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); +typedef DistanceExpCosXequalY DistanceExpCosXequalYF; +TEST_P(DistanceExpCosXequalYF, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? 
m / 2 : m; + + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index 6b6a290386..7bdbb44362 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -24,6 +24,10 @@ template class DistanceEucExpTest : public DistanceTest { }; +template +class DistanceEucExpTestXequalY + : public DistanceTestSameBuffer {}; + const std::vector> inputsf = { {0.001f, 2048, 4096, 128, true, 1234ULL}, {0.001f, 1024, 1024, 32, true, 1234ULL}, @@ -37,6 +41,21 @@ const std::vector> inputsf = { {0.003f, 1024, 1024, 1024, false, 1234ULL}, {0.003f, 1021, 1021, 1021, false, 1234ULL}, }; + +const std::vector> inputsXeqYf = { + {0.01f, 2048, 4096, 128, true, 1234ULL}, + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.03f, 1021, 1021, 1021, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, + {0.03f, 1021, 1021, 1021, false, 1234ULL}, +}; + typedef DistanceEucExpTest DistanceEucExpTestF; TEST_P(DistanceEucExpTestF, Result) { @@ -47,6 +66,27 @@ TEST_P(DistanceEucExpTestF, Result) } INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); +typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYF; +TEST_P(DistanceEucExpTestXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + raft::CompareApprox(params.tolerance), + stream)); + 
ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + raft::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, + DistanceEucExpTestXequalYF, + ::testing::ValuesIn(inputsXeqYf)); + const std::vector> inputsd = { {0.001, 1024, 1024, 32, true, 1234ULL}, {0.001, 1024, 32, 1024, true, 1234ULL}, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 6c7cab3f7b..20d78c7bb5 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -532,6 +532,108 @@ class DistanceTest : public ::testing::TestWithParam> { rmm::device_uvector x, y, dist_ref, dist, dist2; }; +/* + * This test suite verifies the path taken when X and Y are the same buffer. + * Distance metrics that require norms, such as L2 expanded/cosine/correlation, + * take a more optimal path in that case and skip the norm calculation for the Y buffer. + * X and Y may be the same buffer even when the user passes + * different dimensions for them, as in the case of tiled_brute_force_knn. 
+ */ +template +class DistanceTestSameBuffer : public ::testing::TestWithParam> { + public: + using dev_vector = rmm::device_uvector; + DistanceTestSameBuffer() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist2({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}) + { + } + + void SetUp() override + { + auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); + common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.m; + int k = params.k; + DataType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; + if (distanceType == raft::distance::DistanceType::HellingerExpanded || + distanceType == raft::distance::DistanceType::JensenShannon || + distanceType == raft::distance::DistanceType::KLDivergence) { + // Hellinger works only on positive numbers + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + // Russel rao works on boolean values. + bernoulli(handle, r, x.data(), m * k, 0.5f); + } else { + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + } + + for (int i = 0; i < 2; i++) { + // X and Y are the same buffer, but when i = 1 + // different dimensions are passed for x and y. 
+ m = m / (i + 1); + naiveDistance(dist_ref[i].data(), + x.data(), + x.data(), + m, + n, + k, + distanceType, + isRowMajor, + metric_arg, + stream); + + DataType threshold = -10000.f; + + if (isRowMajor) { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + + } else { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + } + } + resource::sync_stream(handle, stream); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + DistanceInputs params; + dev_vector x; + static const int N = 2; + std::array dist_ref, dist, dist2; +}; + template class BigMatrixDistanceTest : public ::testing::Test { public: