Skip to content

Commit

Permalink
Speed up aggregation update using shmem (#10114)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #10114

For cases where we might have enough memory in shmem, we can do faster agg update.

This is 0.97-3.5x faster than thread unsafe `sum1NoSync` version on H100.

Algo:
Within a warp, do atomicAdd within shmem first, and then do atomicAdd in global memory.

Also added a test case with only 1 distinct element(and commented out extremely slow sum1Order as it takes forever to run)

Reviewed By: Yuhta

Differential Revision: D58296793

fbshipit-source-id: 51a5d4cb8e1aab76fe7d21406b3aab03c902e39f
  • Loading branch information
pranjalssh authored and facebook-github-bot committed Jun 10, 2024
1 parent 20069f9 commit c622f91
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 8 deletions.
9 changes: 7 additions & 2 deletions velox/experimental/wave/common/tests/BlockTest.cu
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,11 @@ UPDATE_CASE(updateSum1NoSync, testSumNoSync, 0);
UPDATE_CASE(updateSum1Mtx, testSumMtx, 0);
UPDATE_CASE(updateSum1MtxCoalesce, testSumMtxCoalesce, 0);
UPDATE_CASE(updateSum1Atomic, testSumAtomic, 0);
UPDATE_CASE(updateSum1AtomicCoalesce, testSumAtomicCoalesce, 0);
UPDATE_CASE(updateSum1AtomicCoalesceShfl, testSumAtomicCoalesceShfl, 0);
UPDATE_CASE(
updateSum1AtomicCoalesceShmem,
testSumAtomicCoalesceShmem,
run.blockSize * sizeof(int64_t));
UPDATE_CASE(updateSum1Exch, testSumExch, sizeof(ProbeShared));
UPDATE_CASE(updateSum1Order, testSumOrder, 0);

Expand Down Expand Up @@ -662,7 +666,8 @@ REGISTER_KERNEL("partitionShorts", partitionShortsKernel);
REGISTER_KERNEL("hashTest", hashTestKernel);
REGISTER_KERNEL("allocatorTest", allocatorTestKernel);
REGISTER_KERNEL("sum1atm", updateSum1AtomicKernel);
REGISTER_KERNEL("sum1atmCoa", updateSum1AtomicCoalesceKernel);
REGISTER_KERNEL("sum1atmCoaShfl", updateSum1AtomicCoalesceShflKernel);
REGISTER_KERNEL("sum1atmCoaShmem", updateSum1AtomicCoalesceShmemKernel);
REGISTER_KERNEL("sum1Exch", updateSum1ExchKernel);
REGISTER_KERNEL("sum1Part", updateSum1PartKernel);
REGISTER_KERNEL("partSum", update1PartitionKernel);
Expand Down
3 changes: 2 additions & 1 deletion velox/experimental/wave/common/tests/BlockTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ class BlockTestStream : public Stream {
void updateSum1Atomic(TestingRow* rows, HashRun& run);
void updateSum1Exch(TestingRow* rows, HashRun& run);
void updateSum1NoSync(TestingRow* rows, HashRun& run);
void updateSum1AtomicCoalesce(TestingRow* rows, HashRun& run);
void updateSum1AtomicCoalesceShfl(TestingRow* rows, HashRun& run);
void updateSum1AtomicCoalesceShmem(TestingRow* rows, HashRun& run);
void updateSum1Part(TestingRow* rows, HashRun& run);
void updateSum1Mtx(TestingRow* rows, HashRun& run);
void updateSum1MtxCoalesce(TestingRow* rows, HashRun& run);
Expand Down
22 changes: 18 additions & 4 deletions velox/experimental/wave/common/tests/HashTableTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,15 @@ class HashTableTest : public testing::Test {
case HashTestCase::kUpdateSum1:
UPDATE_CASE("sum1Atm", updateSum1Atomic, true, 0);
UPDATE_CASE("sum1NoSync", updateSum1NoSync, false, 0);
UPDATE_CASE("sum1AtmCoa", updateSum1AtomicCoalesce, true, 1);
UPDATE_CASE("sum1AtmCoaShfl", updateSum1AtomicCoalesceShfl, true, 1);
UPDATE_CASE("sum1AtmCoaShmem", updateSum1AtomicCoalesceShmem, true, 1);
UPDATE_CASE("sum1Mtx", updateSum1Mtx, true, 1);
UPDATE_CASE("sum1MtxCoa", updateSum1MtxCoalesce, true, 0);
UPDATE_CASE("sum1Part", updateSum1Part, true, 0);
UPDATE_CASE("sum1Order", updateSum1Order, true, 0);
// Commenting out Order and Exch functions as they are too slow.
// (for case when only 1 distinct element).

// UPDATE_CASE("sum1Order", updateSum1Order, true, 0);
// UPDATE_CASE("sum1Exch", updateSum1Exch, false, 0);

break;
Expand Down Expand Up @@ -354,18 +358,28 @@ TEST_F(HashTableTest, update) {
{
HashRun run;
run.testCase = HashTestCase::kUpdateSum1;
updateTestCase(1000, 2000000, run);
updateTestCase(10000000, 2000000, run);
}
{
HashRun run;
run.testCase = HashTestCase::kUpdateSum1;
updateTestCase(10000000, 2000000, run);
updateTestCase(100000, 2000000, run);
}
{
HashRun run;
run.testCase = HashTestCase::kUpdateSum1;
updateTestCase(1000, 2000000, run);
}
{
HashRun run;
run.testCase = HashTestCase::kUpdateSum1;
updateTestCase(10, 2000000, run);
}
{
HashRun run;
run.testCase = HashTestCase::kUpdateSum1;
updateTestCase(1, 2000000, run);
}
}

TEST_F(HashTableTest, groupBy) {
Expand Down
42 changes: 41 additions & 1 deletion velox/experimental/wave/common/tests/Updates.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,47 @@ void __device__ testSumAtomic(TestingRow* rows, HashProbe* probe) {
}
}

void __device__ testSumAtomicCoalesce(TestingRow* rows, HashProbe* probe) {
void __device__ testSumAtomicCoalesceShmem(TestingRow* rows, HashProbe* probe) {
constexpr int32_t kWarpThreads = 32;
auto keys = reinterpret_cast<int64_t**>(probe->keys);
auto indices = keys[0];
auto deltas = keys[1];
int32_t base = probe->numRowsPerThread * blockDim.x * blockIdx.x;
int32_t lane = cub::LaneId();
int32_t end = base + probe->numRows[blockIdx.x];
extern __shared__ char smem[];

int64_t* totals = (int64_t*)smem;

for (auto count = base; count < end; count += blockDim.x) {
auto i = threadIdx.x + count;

if (i < end) {
totals[threadIdx.x] = 0;
__syncwarp();
uint32_t laneMask = count + kWarpThreads <= end
? 0xffffffff
: lowMask<uint32_t>(end - count);
auto index = indices[i];
auto delta = deltas[i];
uint32_t allPeers = __match_any_sync(laneMask, index);
int32_t leader = __ffs(allPeers) - 1;
atomicAdd(
(unsigned long long*)&totals
[(threadIdx.x & (~(kWarpThreads - 1))) | leader],
(unsigned long long)delta);
__syncwarp();
if (lane == leader) {
auto* row = &rows[index];
atomicAdd(
(unsigned long long*)&row->count,
(unsigned long long)totals[threadIdx.x]);
}
}
}
}

void __device__ testSumAtomicCoalesceShfl(TestingRow* rows, HashProbe* probe) {
constexpr int32_t kWarpThreads = 32;
auto keys = reinterpret_cast<int64_t**>(probe->keys);
auto indices = keys[0];
Expand Down

0 comments on commit c622f91

Please sign in to comment.