
Commit 2c57fe1

Introduces the new partitioner to implement the reduction StreamK kernel
1 parent a46b725 commit 2c57fe1

7 files changed: +214 -55 lines changed

example/ck_tile/40_streamk_gemm/gemm_utils.hpp
Lines changed: 2 additions & 2 deletions

@@ -30,8 +30,8 @@ struct GemmConfigBase
 template <typename PrecType>
 struct GemmConfigMemoryInterwave : public GemmConfigBase
 {
-    static constexpr ck_tile::index_t M_Tile = 128;
-    static constexpr ck_tile::index_t N_Tile = 128;
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
     static constexpr ck_tile::index_t K_Tile = 32;

     static constexpr ck_tile::index_t M_Warp = 2;
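
Doubling M_Tile and N_Tile from 128 to 256 quadruples the C macro-tile footprint, which matters for the Reduction strategy introduced below: each Stream-K workgroup may stage one partial C tile in the workspace. A rough sanity check of that footprint (a standalone sketch; the fp32 accumulator type and the 304-CTA grid are illustrative assumptions, not values taken from this commit):

#include <cstdio>

int main()
{
    // New tile extents from this commit (GemmConfigMemoryInterwave).
    const long m_tile = 256, n_tile = 256;
    // Assumed for illustration: fp32 accumulators and one Stream-K CTA per CU
    // on a 304-CU GPU; real values come from the partitioner and epilogue types.
    const long acc_bytes   = 4;
    const long num_sk_ctas = 304;

    const long partial_tile_bytes = m_tile * n_tile * acc_bytes; // one staged C tile
    const long flags_bytes        = num_sk_ctas * 4;             // one uint32_t flag per CTA
    const long workspace_bytes    = flags_bytes + num_sk_ctas * partial_tile_bytes;

    std::printf("partial tile: %ld KiB, workspace: ~%ld MiB\n",
                partial_tile_bytes / 1024, workspace_bytes / (1024 * 1024));
    return 0;
}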

example/ck_tile/40_streamk_gemm/run_gemm_example.inc
Lines changed: 12 additions & 16 deletions

@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
            int n_warmup,
            int n_repeat,
            bool flush_cache,
-           ck_tile::StreamKReductionStrategy reduction_strategy,
-           uint32_t num_sk_blocks)
+           ck_tile::StreamKReductionStrategy reduction_strategy)
 {
-    ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
-                                  b_k_n_dev_buf.GetDeviceBuffer(),
-                                  c_m_n_dev_buf.GetDeviceBuffer(),
-                                  M,
-                                  N,
-                                  K,
-                                  stride_A,
-                                  stride_B,
-                                  stride_C,
-                                  reduction_strategy,
-                                  num_sk_blocks};
+    ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
+                                          b_k_n_dev_buf.GetDeviceBuffer(),
+                                          c_m_n_dev_buf.GetDeviceBuffer(),
+                                          M,
+                                          N,
+                                          K,
+                                          stride_A,
+                                          stride_B,
+                                          stride_C,
+                                          reduction_strategy};

     std::tuple<float, ck_tile::index_t> ave_time_and_batch;

@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,

     ck_tile::StreamKReductionStrategy reduction_strategy =
         get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
-    uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));

     stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
                                      n_warmup,
                                      n_repeat,
                                      flush_cache,
-                                     reduction_strategy,
-                                     num_sk_blocks);
+                                     reduction_strategy);

     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());


example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
Lines changed: 22 additions & 17 deletions

@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT

 #include "gemm_utils.hpp"
-#include "run_gemm_example.inc"
 #include "ck_tile/ops/common.hpp"

 template <typename GemmConfig,
@@ -17,9 +16,8 @@ template <typename GemmConfig,
           typename ELayout,
           typename CDEElementWise,
           ck_tile::StreamKReductionStrategy ReductionStrategy>
-std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
+std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
                                          const ck_tile::stream_config& s)
-
 {
     using GemmShape = ck_tile::TileGemmShape<
         ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
        GemmConfig::PermuteA,
        GemmConfig::PermuteB>;

-    using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
+    using TilePartitioner =
+        ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;

    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                                                 GemmConfig::kPadN,
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                                             memory_operation.value,
                                             GemmConfig::NumWaveGroups>>;

-        using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+        using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

-        auto kargs = Kernel::MakeKernelArgs(args);
+        auto kargs                = Kernel::MakeKernelArgs(args);
+        const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
+        ck_tile::DeviceMem workspace_data(workspace_size);
+        workspace_data.SetZero();
+        kargs.workspace_ptr = workspace_data.GetDeviceBuffer();

        dim3 grids  = Kernel::GridSize(kargs.tile_partitioner);
        dim3 blocks = Kernel::BlockSize();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                      << std::endl;
        }

-        // Function to clear the output C tensor results after each repetition of the kernel
-        auto clear_gemm_output = [&]() {
+        auto reset_data_buffers = [&]() {
            if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
+            {
+                // Clear the output C tensor results after each repetition of the kernel
                hipGetErrorString(hipMemsetAsync(
                    args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+            }
+            else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+            {
+                // Reset sk flags to zero before each repetition of the kernel
+                workspace_data.SetZero();
+            }
        };

-        std::function<void()> preprocess = clear_gemm_output;
+        std::function<void()> preprocess = reset_data_buffers;

        float ave_time = ck_tile::launch_kernel_time_mask(
            s,
            preprocess,
            ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

-        ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
-            kargs.tile_partitioner.sk_num_blocks,
-            // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
-            // big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
-            // avoid division by zero errors.
-            ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
-            kargs.tile_partitioner.k_iters_per_tile.get());
-
+        ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
        return std::tuple{ave_time, num_wgs_per_tile};
    };

@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
        }
    }

+#include "run_gemm_example.inc"
+
 template <typename GemmConfig, typename TypeConfig>
 int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
 {
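
The reset_data_buffers preprocess exists because a timed run repeats the kernel: with the Atomic strategy each launch accumulates into C in place, so C must be cleared between repetitions, while with the Reduction strategy the sk flags in the workspace are left set after a run and must be re-zeroed, or the next repetition's tile owners would consume stale partials. A host-only analogue of the Atomic case (illustrative names, no GPU required):

#include <algorithm>
#include <cassert>
#include <vector>

// Host-side analogue of why the preprocess step is needed between repeats:
// each "launch" accumulates partial results into C, like atomicAdd does in
// the real Stream-K kernel under the Atomic strategy.
void fake_streamk_launch(std::vector<float>& c, const std::vector<float>& partials)
{
    for(std::size_t i = 0; i < c.size(); ++i)
        c[i] += partials[i]; // atomicAdd in the real kernel
}

int main()
{
    std::vector<float> c(4, 0.0f), partials(4, 1.0f);

    for(int rep = 0; rep < 3; ++rep)
    {
        std::fill(c.begin(), c.end(), 0.0f); // the preprocess (hipMemsetAsync on device)
        fake_streamk_launch(c, partials);
        assert(c[0] == 1.0f); // without the fill, rep 2 would see 2.0f, rep 3 would see 3.0f
    }
    return 0;
}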

include/ck_tile/host/kernel_launch.hpp
Lines changed: 4 additions & 0 deletions

@@ -81,6 +81,10 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
 {
    for(int i = 0; i < s.cold_niters_; i++)
    {
+        if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
+        {
+            preprocess();
+        }
        callables_func();
    }
    // Only profile preprocess if it's provided
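
The guard is needed because PreprocessFunc may be std::nullptr_t when the caller passes no preprocess; if constexpr discards the call at compile time, whereas a runtime branch would still have to compile preprocess() against nullptr. A minimal sketch of the same dispatch with a simplified signature (not the actual ck_tile timing loop):

#include <cstdio>
#include <type_traits>

// Simplified sketch of the timing-loop dispatch: preprocess is either a
// callable or std::nullptr_t, and in the latter case the call must be
// discarded at compile time, since invoking nullptr would not compile.
template <typename PreprocessFunc, typename KernelFunc>
void cold_loop(int iters, PreprocessFunc preprocess, KernelFunc kernel)
{
    for(int i = 0; i < iters; ++i)
    {
        if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
        {
            preprocess(); // e.g. reset_data_buffers between repetitions
        }
        kernel();
    }
}

int main()
{
    cold_loop(2, nullptr, [] { std::puts("kernel"); });                    // no preprocess
    cold_loop(2, [] { std::puts("reset"); }, [] { std::puts("kernel"); }); // with preprocess
    return 0;
}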

include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp
Lines changed: 168 additions & 13 deletions

@@ -7,6 +7,7 @@
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/host/concat.hpp"

+const bool print_log = true;
 namespace ck_tile {
 namespace reboot {

@@ -84,9 +85,10 @@ struct StreamKKernel
    using CLayout = typename GemmPipeline::CLayout;

    /// @brief Specify the data type configurations for A, B, and C
-    using ADataType = typename GemmPipeline::ADataType;
-    using BDataType = typename GemmPipeline::BDataType;
-    using CDataType = typename EpiloguePipeline::ODataType;
+    using ADataType   = typename GemmPipeline::ADataType;
+    using BDataType   = typename GemmPipeline::BDataType;
+    using CDataType   = typename EpiloguePipeline::ODataType;
+    using AccDataType = typename EpiloguePipeline::AccDataType;

    template <typename T>
    static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
@@ -243,22 +245,14 @@

    CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
    {
-        if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
-        {
-            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
-            {
-                CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
-            }
-            return false;
-        }
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    /// @brief Computes the buffer size needed to store accumulation results for Stream K.
    /// @return The buffer size needed.
    CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
    {
-        return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
+        return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
    }

    /// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
@@ -299,6 +293,89 @@
            {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
    }

+    CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
+                                               index_t cta_idx) const
+    {
+        auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
+        workgroup_barrier sk_flags(sk_flags_ptr);
+        sk_flags.wait_set(0, 1, cta_idx);
+    }
+
+    CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
+    {
+        auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
+        workgroup_barrier sk_flags(sk_flags_ptr);
+        sk_flags.wait_eq(1, cta_idx);
+    }
+
+    template <typename OAccTile>
+    CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
+                                     const OAccTile& in_block_tile) const
+    {
+        using BlockType        = remove_cvref_t<decltype(in_out_block_tile)>;
+        constexpr auto o_spans = BlockType::get_distributed_spans();
+        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
+            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
+                constexpr auto idx     = make_tuple(idx0, idx1);
+                in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
+            });
+        });
+    }
+
+    template <typename DataType, typename OAccTileDist>
+    CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
+                                    index_t cta_idx,
+                                    const OAccTileDist& c_block_tile_dist) const
+    {
+        const auto c_block_tile_buffer_size =
+            TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
+        void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
+                                   kargs.tile_partitioner.get_flags_buffer_size() +
+                                   cta_idx * c_block_tile_buffer_size;
+
+        const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
+            static_cast<DataType*>(partial_buffer_ptr),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            make_tuple(TilePartitioner::NPerBlock, 1),
+            number<GemmPipeline::GetVectorSizeC()>{},
+            number<1>{});
+
+        auto partial_tile_window = make_tile_window(
+            partial_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {0, 0},
+            c_block_tile_dist);
+
+        return load_tile(partial_tile_window);
+    }
+
+    template <typename OAccTile>
+    CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
+                                     index_t cta_idx,
+                                     const OAccTile& c_block_tile) const
+    {
+        const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
+                                              TilePartitioner::NPerBlock *
+                                              sizeof(typename OAccTile::DataType);
+        void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
+                                   kargs.tile_partitioner.get_flags_buffer_size() +
+                                   cta_idx * c_block_tile_buffer_size;
+
+        const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
+            static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            make_tuple(TilePartitioner::NPerBlock, 1),
+            number<GemmPipeline::GetVectorSizeC()>{},
+            number<1>{});
+
+        auto partial_tile_window = make_tile_window(
+            partial_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {0, 0});
+
+        store_tile(partial_tile_window, c_block_tile);
+    }
+
    /// @brief Runs the main Stream-K algorithm.
    /// @param kargs Stream-K kernel arguments.
    /// @param cta_idx The current Stream-K workgroup's index.
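
Together these helpers implement a release/acquire handshake through global memory: a workgroup that did not start its tile stores its partial C tile (StorePartial), makes the store visible with a fence, then raises its flag (SignalStorePartialDone); the workgroup that owns the tile start spins on that flag (WaitStorePartialDone) before reading the partial back (LoadPartial) and folding it in (AddBlockTile). A host-threads analogue of the protocol using std::atomic (illustrative only; the device code uses workgroup_barrier and __threadfence instead):

#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Host analogue of the partial-result handshake: the producer plays the role
// of a CTA that does not start the tile (StorePartial + fence + flag set),
// the consumer the CTA that owns the tile start (wait on flag + LoadPartial).
int main()
{
    std::vector<float> partial(4);       // stands in for this CTA's workspace tile slot
    std::atomic<uint32_t> flag{0};       // stands in for this CTA's sk flag

    std::thread producer([&] {
        for(auto& v : partial) v = 2.0f;          // StorePartial
        flag.store(1, std::memory_order_release); // __threadfence + SignalStorePartialDone
    });

    std::thread consumer([&] {
        while(flag.load(std::memory_order_acquire) != 1) // WaitStorePartialDone
        {
        }
        float acc = 1.0f + partial[0]; // AddBlockTile(LoadPartial(...))
        assert(acc == 3.0f);
    });

    producer.join();
    consumer.join();
    return 0;
}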
@@ -347,7 +424,85 @@
                }
                else
                {
-                    // TODO: Apply reduction logic.
+                    const auto c_macro_tile_idx =
+                        kargs.tile_partitioner.get_output_tile_index(tile_idx);
+                    index_t i_m =
+                        c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
+                    index_t i_n =
+                        c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
+
+                    const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
+                    const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
+                    CDataType* c_ptr       = static_cast<CDataType*>(kargs.e_ptr);
+
+                    // Create Gemm tensor views, pad views and tile windows
+                    const auto& gemm_tensor_views_tuple =
+                        UniversalGemmKernel::template MakeGemmTensorViews<
+                            EpiloguePipeline::MemoryOperation>(
+                            {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
+
+                    const auto& gemm_pad_views =
+                        UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
+                    auto gemm_tile_windows =
+                        UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
+
+                    // Run GEMM cooperatively by whole workgroup.
+                    const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
+                    const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
+                    const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
+
+                    // Since num_loop can vary per WG and per iteration of the Stream-K while loop,
+                    // we compute has_hot_loop and tail_num here. This is a similar pattern used by
+                    // grouped GEMM. In this case, we call the GemmPipeline's operator() function
+                    // that takes both has_hot_loop and tail_num.
+                    const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop_sk);
+                    const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
+
+                    const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
+                                                              bs_block_window[UniversalGemmKernel::I0],
+                                                              num_loop_sk,
+                                                              has_hot_loop,
+                                                              tail_num,
+                                                              smem_ptr_0);
+
+                    auto tile_started = iter_start == tile_iter_start;
+                    auto tile_ended   = iter_end >= tile_iter_end;
+                    if(!tile_started)
+                    {
+                        StorePartial(kargs, cta_idx, c_block_tile);
+                        __threadfence(); // send signal when the store is done
+                        SignalStorePartialDone(kargs, cta_idx);
+                    }
+                    else
+                    {
+                        auto accum_block_tile = c_block_tile;
+                        if(!tile_ended)
+                        {
+                            const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
+                            const index_t iter_per_cta  = kargs.tile_partitioner.get_iters_per_sk_cta();
+                            const index_t extra_iters   = kargs.tile_partitioner.get_extra_iters();
+                            int accum_iters = local_iter_end - local_iter_start;
+                            int next_cta    = cta_idx + 1;
+
+                            while(accum_iters < iter_per_tile)
+                            {
+                                WaitStorePartialDone(kargs, next_cta);
+
+                                using BlockType = remove_cvref_t<decltype(c_block_tile)>;
+                                AddBlockTile(
+                                    accum_block_tile,
+                                    LoadPartial<typename BlockType::DataType>(
+                                        kargs, next_cta, c_block_tile.get_tile_distribution()));
+
+                                accum_iters += iter_per_cta + (next_cta < extra_iters);
+                                ++next_cta;
+                            }
+                        }
+
+                        auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
+                        EpiloguePipeline{}(
+                            c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
+                    }
                }

                // Prepare for next Stream-K loop iteration.
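
The accumulation loop's bookkeeping rests on the v2 partitioner's contiguous split of the K-iteration space: every Stream-K CTA covers iters_per_sk_cta iterations and the first extra_iters CTAs cover one more, so the tile owner can tell how many neighbours' partials to fold in purely from counts, without reading any metadata. The count may overshoot the tile boundary, which is harmless since it is only compared against iters_per_tile. A host-side replay of that arithmetic (the concrete sizes are assumptions for illustration):

#include <cassert>
#include <cstdio>

int main()
{
    // Illustrative Stream-K split (values assumed): 4 output tiles of 16
    // K-iterations, shared by 6 CTAs over a contiguous space of 64 iterations.
    const int iters_per_tile = 16;
    const int total_iters    = 4 * iters_per_tile;
    const int num_sk_ctas    = 6;
    const int iters_per_cta  = total_iters / num_sk_ctas; // 10
    const int extra_iters    = total_iters % num_sk_ctas; // first 4 CTAs do one more

    auto cta_iters = [&](int cta) { return iters_per_cta + (cta < extra_iters); };

    // CTA 0 starts tile 0 but does not finish it (11 of 16 iterations), so it
    // replays the kernel's loop: fold in partials from the CTAs that follow
    // until the whole tile is covered.
    int accum_iters = cta_iters(0); // local_iter_end - local_iter_start
    int next_cta    = 1;
    while(accum_iters < iters_per_tile)
    {
        // Kernel: WaitStorePartialDone(next_cta); AddBlockTile(LoadPartial(next_cta));
        accum_iters += cta_iters(next_cta); // iter_per_cta + (next_cta < extra_iters)
        ++next_cta;
    }
    std::printf("tile 0 = CTA 0 + partials from CTAs 1..%d\n", next_cta - 1);
    assert(next_cta == 2); // with these numbers, one neighbour's partial suffices
    return 0;
}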
