From 35411a0b0ad74c7d54c323225a07952ec81cceec Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Thu, 15 Jul 2021 21:00:53 +0530 Subject: [PATCH] Overlap epilog compute with ldg of next grid stride in pairwise distance & fusedL2NN kernels (#292) overlap epilog compute with ldg of next grid stride in pairwise distance base class. gives 2-3% perf improvement for pairwise distance kernels and fusedL2NN kernel. Authors: - Mahesh Doijade (https://github.com/mdoijade) Approvers: - Thejaswi. N. S (https://github.com/teju85) URL: https://github.com/rapidsai/raft/pull/292 --- .../raft/distance/pairwise_distance_base.cuh | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index d5a434f2fa..43abc9eb65 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -136,39 +136,29 @@ struct PairwiseDistances : public BaseClass { this->xrowid += stride; } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { - if (gridStrideX > blockIdx.x * P::Nblk) { + DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) { + // Fetch next grid stride ldg if within range + if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { updateIndicesY(); - } else if (gridStrideY > blockIdx.y * P::Mblk) { + this->ldgXY(0); + } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { updateIndicesXY(); + this->ldgXY(0); } + } - typedef TxN_t VecType; - VecType zeros; - zeros.fill(BaseClass::Zero); -#pragma unroll - for (int j = 0; j < P::LdgPerThX; ++j) { - zeros.store(&this->ldgDataX[j][0], 0); - } -#pragma unroll - for (int j = 0; j < P::LdgPerThY; ++j) { - zeros.store(&this->ldgDataY[j][0], 0); + DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { + if (gridStrideX == blockIdx.x * P::Nblk) { + this->ldgXY(0); } - this->ldgXY(0); - #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - zeros.store(&this->regx[i][0], 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; } } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - zeros.store(&this->regy[j][0], 0); - } this->stsXY(); __syncthreads(); @@ -239,8 +229,12 @@ struct PairwiseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { + // Overlap ldg with epilog computation + ldgNextGridStride(gridStrideX, gridStrideY); epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); }