Skip to content

Commit

Permalink
Overlap epilog compute with ldg of next grid stride in pairwise dista…
Browse files Browse the repository at this point in the history
…nce & fusedL2NN kernels (#292)

overlap epilog compute with ldg of next grid stride in pairwise distance base class. 
gives 2-3% perf improvement for pairwise distance kernels and fusedL2NN kernel.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)

URL: #292
  • Loading branch information
mdoijade authored Jul 15, 2021
1 parent f94780c commit 35411a0
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions cpp/include/raft/distance/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,29 @@ struct PairwiseDistances : public BaseClass {
this->xrowid += stride;
}

DI void prolog(IdxT gridStrideX, IdxT gridStrideY) {
if (gridStrideX > blockIdx.x * P::Nblk) {
DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) {
// Fetch next grid stride ldg if within range
if ((gridStrideX + gridDim.x * P::Nblk) < this->n) {
updateIndicesY();
} else if (gridStrideY > blockIdx.y * P::Mblk) {
this->ldgXY(0);
} else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) {
updateIndicesXY();
this->ldgXY(0);
}
}

typedef TxN_t<DataT, P::Veclen> VecType;
VecType zeros;
zeros.fill(BaseClass::Zero);
#pragma unroll
for (int j = 0; j < P::LdgPerThX; ++j) {
zeros.store(&this->ldgDataX[j][0], 0);
}
#pragma unroll
for (int j = 0; j < P::LdgPerThY; ++j) {
zeros.store(&this->ldgDataY[j][0], 0);
DI void prolog(IdxT gridStrideX, IdxT gridStrideY) {
if (gridStrideX == blockIdx.x * P::Nblk) {
this->ldgXY(0);
}

this->ldgXY(0);

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
zeros.store(&this->regx[i][0], 0);
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
zeros.store(&this->regy[j][0], 0);
}

this->stsXY();
__syncthreads();
Expand Down Expand Up @@ -239,8 +229,12 @@ struct PairwiseDistances : public BaseClass {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY);
}

Expand Down

0 comments on commit 35411a0

Please sign in to comment.