Skip to content

Commit

Permalink
[HOTFIX] Fix distance metrics L2/cosine/correlation when X & Y are sa…
Browse files Browse the repository at this point in the history
…me buffer but with different shape and add unit test for such case. (#1571)

-- This is how tiled_brute_force_knn may use pairwise distance API hence assuming when X == Y the buffer has same shape is incorrect.

Authors:
   - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
   - Tamas Bela Feher (https://github.com/tfeher)
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
mdoijade authored Jun 6, 2023
1 parent 7d3bed2 commit fc979fe
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 33 deletions.
79 changes: 46 additions & 33 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ void distance_impl(raft::resources const& handle,
bool is_row_major,
DataT) // unused
{
ASSERT(
!(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
Expand All @@ -137,9 +135,27 @@ void distance_impl(raft::resources const& handle,
AccT* y_norm = workspace;
AccT* sq_x_norm = workspace;
AccT* sq_y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::reduce(x_norm,
x,
k,
std::max(m, n),
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
sq_x_norm += std::max(m, n);
sq_y_norm = sq_x_norm;
raft::linalg::rowNorm(
sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream);
} else {
y_norm += m;

raft::linalg::reduce(x_norm,
x,
k,
Expand Down Expand Up @@ -167,21 +183,6 @@ void distance_impl(raft::resources const& handle,
sq_y_norm = sq_x_norm + m;
raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream);
} else {
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
sq_x_norm += m;
sq_y_norm = sq_x_norm;
raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream);
}

using OpT = ops::correlation_distance_op<DataT, AccT, IdxT>;
Expand Down Expand Up @@ -210,23 +211,25 @@ void distance_impl(raft::resources const& handle,
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::rowNorm(
x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
} else {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
} else {
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
}

ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
Expand Down Expand Up @@ -453,21 +456,29 @@ void distance_impl_l2_expanded( // NOTE: different name
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if ((x == y) && is_row_major) {
raft::linalg::rowNorm(x_norm,
x,
k,
std::max(m, n),
raft::linalg::L2Norm,
is_row_major,
stream,
raft::identity_op{});
} else {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
} else {
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
}

ops::l2_exp_distance_op<DataT, AccT, IdxT> distance_op{perform_sqrt};
Expand Down Expand Up @@ -789,8 +800,10 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1;

if (is_allocated) {
// TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input
// accuracy issue is resolved until then we allocate as m + n.
worksize += numOfBuffers * m * sizeof(AccType);
if (x != y) worksize += numOfBuffers * n * sizeof(AccType);
worksize += numOfBuffers * n * sizeof(AccType);
}

return worksize;
Expand Down
23 changes: 23 additions & 0 deletions cpp/test/distance/dist_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceCorrelation
: public DistanceTest<raft::distance::DistanceType::CorrelationExpanded, DataType> {};

template <typename DataType>
class DistanceCorrelationXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::CorrelationExpanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
Expand All @@ -44,6 +48,25 @@ TEST_P(DistanceCorrelationF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf));

typedef DistanceCorrelationXequalY<float> DistanceCorrelationXequalYF;
TEST_P(DistanceCorrelationXequalYF, Result)
{
int m = params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(),
dist[0].data(),
m,
m,
raft::CompareApprox<float>(params.tolerance),
stream));
ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(),
dist[1].data(),
m / 2,
m,
raft::CompareApprox<float>(params.tolerance),
stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
39 changes: 39 additions & 0 deletions cpp/test/distance/dist_cos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceExpCos : public DistanceTest<raft::distance::DistanceType::CosineExpanded, DataType> {
};

template <typename DataType>
class DistanceExpCosXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::CosineExpanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
Expand All @@ -34,6 +38,18 @@ const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 32, 1024, 1024, false, 1234ULL},
{0.003f, 1024, 1024, 1024, false, 1234ULL},
};

const std::vector<DistanceInputs<float>> inputsXeqYf = {
{0.01f, 1024, 1024, 32, true, 1234ULL},
{0.01f, 1024, 32, 1024, true, 1234ULL},
{0.01f, 32, 1024, 1024, true, 1234ULL},
{0.03f, 1024, 1024, 1024, true, 1234ULL},
{0.01f, 1024, 1024, 32, false, 1234ULL},
{0.01f, 1024, 32, 1024, false, 1234ULL},
{0.01f, 32, 1024, 1024, false, 1234ULL},
{0.03f, 1024, 1024, 1024, false, 1234ULL},
};

typedef DistanceExpCos<float> DistanceExpCosF;
TEST_P(DistanceExpCosF, Result)
{
Expand All @@ -44,6 +60,29 @@ TEST_P(DistanceExpCosF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf));

typedef DistanceExpCosXequalY<float> DistanceExpCosXequalYF;
TEST_P(DistanceExpCosXequalYF, Result)
{
int m = params.m;
int n = params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(),
dist[0].data(),
m,
n,
raft::CompareApprox<float>(params.tolerance),
stream));
n = params.isRowMajor ? m : m / 2;
m = params.isRowMajor ? m / 2 : m;

ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(),
dist[1].data(),
m,
n,
raft::CompareApprox<float>(params.tolerance),
stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
40 changes: 40 additions & 0 deletions cpp/test/distance/dist_l2_exp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceEucExpTest : public DistanceTest<raft::distance::DistanceType::L2Expanded, DataType> {
};

template <typename DataType>
class DistanceEucExpTestXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::L2Expanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 2048, 4096, 128, true, 1234ULL},
{0.001f, 1024, 1024, 32, true, 1234ULL},
Expand All @@ -37,6 +41,21 @@ const std::vector<DistanceInputs<float>> inputsf = {
{0.003f, 1024, 1024, 1024, false, 1234ULL},
{0.003f, 1021, 1021, 1021, false, 1234ULL},
};

const std::vector<DistanceInputs<float>> inputsXeqYf = {
{0.01f, 2048, 4096, 128, true, 1234ULL},
{0.01f, 1024, 1024, 32, true, 1234ULL},
{0.01f, 1024, 32, 1024, true, 1234ULL},
{0.01f, 32, 1024, 1024, true, 1234ULL},
{0.03f, 1024, 1024, 1024, true, 1234ULL},
{0.03f, 1021, 1021, 1021, true, 1234ULL},
{0.01f, 1024, 1024, 32, false, 1234ULL},
{0.01f, 1024, 32, 1024, false, 1234ULL},
{0.01f, 32, 1024, 1024, false, 1234ULL},
{0.03f, 1024, 1024, 1024, false, 1234ULL},
{0.03f, 1021, 1021, 1021, false, 1234ULL},
};

typedef DistanceEucExpTest<float> DistanceEucExpTestF;
TEST_P(DistanceEucExpTestF, Result)
{
Expand All @@ -47,6 +66,27 @@ TEST_P(DistanceEucExpTestF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf));

typedef DistanceEucExpTestXequalY<float> DistanceEucExpTestXequalYF;
TEST_P(DistanceEucExpTestXequalYF, Result)
{
int m = params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(),
dist[0].data(),
m,
m,
raft::CompareApprox<float>(params.tolerance),
stream));
ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(),
dist[1].data(),
m / 2,
m,
raft::CompareApprox<float>(params.tolerance),
stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests,
DistanceEucExpTestXequalYF,
::testing::ValuesIn(inputsXeqYf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
Loading

0 comments on commit fc979fe

Please sign in to comment.