diff --git a/velox/experimental/wave/common/tests/BlockTest.cu b/velox/experimental/wave/common/tests/BlockTest.cu
index eab1173a0f40..ade413839d93 100644
--- a/velox/experimental/wave/common/tests/BlockTest.cu
+++ b/velox/experimental/wave/common/tests/BlockTest.cu
@@ -502,7 +502,11 @@ UPDATE_CASE(updateSum1NoSync, testSumNoSync, 0);
 UPDATE_CASE(updateSum1Mtx, testSumMtx, 0);
 UPDATE_CASE(updateSum1MtxCoalesce, testSumMtxCoalesce, 0);
 UPDATE_CASE(updateSum1Atomic, testSumAtomic, 0);
-UPDATE_CASE(updateSum1AtomicCoalesce, testSumAtomicCoalesce, 0);
+UPDATE_CASE(updateSum1AtomicCoalesceShfl, testSumAtomicCoalesceShfl, 0);
+UPDATE_CASE(
+    updateSum1AtomicCoalesceShmem,
+    testSumAtomicCoalesceShmem,
+    run.blockSize * sizeof(int64_t));
 UPDATE_CASE(updateSum1Exch, testSumExch, sizeof(ProbeShared));
 UPDATE_CASE(updateSum1Order, testSumOrder, 0);
 
@@ -662,7 +666,8 @@ REGISTER_KERNEL("partitionShorts", partitionShortsKernel);
 REGISTER_KERNEL("hashTest", hashTestKernel);
 REGISTER_KERNEL("allocatorTest", allocatorTestKernel);
 REGISTER_KERNEL("sum1atm", updateSum1AtomicKernel);
-REGISTER_KERNEL("sum1atmCoa", updateSum1AtomicCoalesceKernel);
+REGISTER_KERNEL("sum1atmCoaShfl", updateSum1AtomicCoalesceShflKernel);
+REGISTER_KERNEL("sum1atmCoaShmem", updateSum1AtomicCoalesceShmemKernel);
 REGISTER_KERNEL("sum1Exch", updateSum1ExchKernel);
 REGISTER_KERNEL("sum1Part", updateSum1PartKernel);
 REGISTER_KERNEL("partSum", update1PartitionKernel);
diff --git a/velox/experimental/wave/common/tests/BlockTest.h b/velox/experimental/wave/common/tests/BlockTest.h
index 2c79469113d0..5218537141a3 100644
--- a/velox/experimental/wave/common/tests/BlockTest.h
+++ b/velox/experimental/wave/common/tests/BlockTest.h
@@ -148,7 +148,8 @@ class BlockTestStream : public Stream {
   void updateSum1Atomic(TestingRow* rows, HashRun& run);
   void updateSum1Exch(TestingRow* rows, HashRun& run);
   void updateSum1NoSync(TestingRow* rows, HashRun& run);
-  void updateSum1AtomicCoalesce(TestingRow* rows, HashRun& run);
+  void updateSum1AtomicCoalesceShfl(TestingRow* rows, HashRun& run);
+  void updateSum1AtomicCoalesceShmem(TestingRow* rows, HashRun& run);
   void updateSum1Part(TestingRow* rows, HashRun& run);
   void updateSum1Mtx(TestingRow* rows, HashRun& run);
   void updateSum1MtxCoalesce(TestingRow* rows, HashRun& run);
diff --git a/velox/experimental/wave/common/tests/HashTableTest.cpp b/velox/experimental/wave/common/tests/HashTableTest.cpp
index 0e5662911411..bae0fd4667dc 100644
--- a/velox/experimental/wave/common/tests/HashTableTest.cpp
+++ b/velox/experimental/wave/common/tests/HashTableTest.cpp
@@ -145,11 +145,15 @@ class HashTableTest : public testing::Test {
       case HashTestCase::kUpdateSum1:
         UPDATE_CASE("sum1Atm", updateSum1Atomic, true, 0);
         UPDATE_CASE("sum1NoSync", updateSum1NoSync, false, 0);
-        UPDATE_CASE("sum1AtmCoa", updateSum1AtomicCoalesce, true, 1);
+        UPDATE_CASE("sum1AtmCoaShfl", updateSum1AtomicCoalesceShfl, true, 1);
+        UPDATE_CASE("sum1AtmCoaShmem", updateSum1AtomicCoalesceShmem, true, 1);
         UPDATE_CASE("sum1Mtx", updateSum1Mtx, true, 1);
         UPDATE_CASE("sum1MtxCoa", updateSum1MtxCoalesce, true, 0);
         UPDATE_CASE("sum1Part", updateSum1Part, true, 0);
-        UPDATE_CASE("sum1Order", updateSum1Order, true, 0);
+        // Commenting out Order and Exch functions as they are too slow.
+        // (for case when only 1 distinct element).
+
+        // UPDATE_CASE("sum1Order", updateSum1Order, true, 0);
         // UPDATE_CASE("sum1Exch", updateSum1Exch, false, 0);
         break;
 
@@ -354,18 +358,28 @@ TEST_F(HashTableTest, update) {
   {
     HashRun run;
     run.testCase = HashTestCase::kUpdateSum1;
-    updateTestCase(1000, 2000000, run);
+    updateTestCase(10000000, 2000000, run);
   }
   {
     HashRun run;
     run.testCase = HashTestCase::kUpdateSum1;
-    updateTestCase(10000000, 2000000, run);
+    updateTestCase(100000, 2000000, run);
+  }
+  {
+    HashRun run;
+    run.testCase = HashTestCase::kUpdateSum1;
+    updateTestCase(1000, 2000000, run);
   }
   {
     HashRun run;
     run.testCase = HashTestCase::kUpdateSum1;
     updateTestCase(10, 2000000, run);
   }
+  {
+    HashRun run;
+    run.testCase = HashTestCase::kUpdateSum1;
+    updateTestCase(1, 2000000, run);
+  }
 }
 
 TEST_F(HashTableTest, groupBy) {
diff --git a/velox/experimental/wave/common/tests/Updates.cuh b/velox/experimental/wave/common/tests/Updates.cuh
index e25b2c918be8..fd0aa491e291 100644
--- a/velox/experimental/wave/common/tests/Updates.cuh
+++ b/velox/experimental/wave/common/tests/Updates.cuh
@@ -102,7 +102,60 @@ void __device__ testSumAtomic(TestingRow* rows, HashProbe* probe) {
   }
 }
 
-void __device__ testSumAtomicCoalesce(TestingRow* rows, HashProbe* probe) {
+// Adds deltas[i] to rows[indices[i]].count for every row i assigned to this
+// block. Lanes of a warp that update the same row are grouped with
+// __match_any_sync; each group sums its deltas into the shared memory slot of
+// its lowest lane, and only that lane issues the atomicAdd on the row. Needs
+// blockDim.x * sizeof(int64_t) bytes of dynamic shared memory.
+void __device__ testSumAtomicCoalesceShmem(TestingRow* rows, HashProbe* probe) {
+  constexpr int32_t kWarpThreads = 32;
+  auto keys = reinterpret_cast<int64_t**>(probe->keys);
+  auto indices = keys[0];
+  auto deltas = keys[1];
+  int32_t base = probe->numRowsPerThread * blockDim.x * blockIdx.x;
+  int32_t lane = cub::LaneId();
+  int32_t end = base + probe->numRows[blockIdx.x];
+  extern __shared__ char smem[];
+
+  // One accumulator slot per thread of the block.
+  int64_t* totals = (int64_t*)smem;
+  // threadIdx.x of lane 0 of this thread's warp.
+  int32_t warpStart = threadIdx.x & ~(kWarpThreads - 1);
+
+  for (auto count = base; count < end; count += blockDim.x) {
+    auto i = threadIdx.x + count;
+
+    if (i < end) {
+      // The mask must name exactly the lanes of this warp that pass the
+      // 'i < end' check, so it is derived from the warp's first row and not
+      // from 'count', which is the first row of the whole block.
+      int32_t warpRow = count + warpStart;
+      uint32_t laneMask = warpRow + kWarpThreads <= end
+          ? 0xffffffff
+          : lowMask<uint32_t>(end - warpRow);
+      totals[threadIdx.x] = 0;
+      // Every slot is zeroed before any peer adds to it.
+      __syncwarp(laneMask);
+      auto index = indices[i];
+      auto delta = deltas[i];
+      uint32_t allPeers = __match_any_sync(laneMask, index);
+      int32_t leader = __ffs(allPeers) - 1;
+      atomicAdd(
+          (unsigned long long*)&totals[warpStart | leader],
+          (unsigned long long)delta);
+      // All peers have added their delta before the leader reads the sum.
+      __syncwarp(laneMask);
+      if (lane == leader) {
+        auto* row = &rows[index];
+        atomicAdd(
+            (unsigned long long*)&row->count,
+            (unsigned long long)totals[threadIdx.x]);
+      }
+    }
+  }
+}
+
+void __device__ testSumAtomicCoalesceShfl(TestingRow* rows, HashProbe* probe) {
   constexpr int32_t kWarpThreads = 32;
   auto keys = reinterpret_cast<int64_t**>(probe->keys);
   auto indices = keys[0];