|
7 | 7 | #include "ck_tile/ops/common.hpp" |
8 | 8 | #include "ck_tile/host/concat.hpp" |
9 | 9 |
|
| 10 | +const bool print_log = true; |
10 | 11 | namespace ck_tile { |
11 | 12 | namespace reboot { |
12 | 13 |
|
@@ -84,9 +85,10 @@ struct StreamKKernel |
84 | 85 | using CLayout = typename GemmPipeline::CLayout; |
85 | 86 |
|
86 | 87 | /// @brief Specify the data type configurations for A, B, and C |
87 | | - using ADataType = typename GemmPipeline::ADataType; |
88 | | - using BDataType = typename GemmPipeline::BDataType; |
89 | | - using CDataType = typename EpiloguePipeline::ODataType; |
| 88 | + using ADataType = typename GemmPipeline::ADataType; |
| 89 | + using BDataType = typename GemmPipeline::BDataType; |
| 90 | + using CDataType = typename EpiloguePipeline::ODataType; |
| 91 | + using AccDataType = typename EpiloguePipeline::AccDataType; |
90 | 92 |
|
91 | 93 | template <typename T> |
92 | 94 | static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value; |
@@ -243,22 +245,14 @@ struct StreamKKernel |
243 | 245 |
|
244 | 246 | CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) |
245 | 247 | { |
246 | | - if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) |
247 | | - { |
248 | | - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) |
249 | | - { |
250 | | - CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); |
251 | | - } |
252 | | - return false; |
253 | | - } |
254 | 248 | return UniversalGemmKernel::IsSupportedArgument(kargs); |
255 | 249 | } |
256 | 250 |
|
257 | 251 | /// @brief Computes the buffer size needed to store accumulation results for Stream K. |
258 | 252 | /// @return The buffer size needed. |
259 | 253 | CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs) |
260 | 254 | { |
261 | | - return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType)); |
| 255 | + return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType)); |
262 | 256 | } |
263 | 257 |
|
264 | 258 | /// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr. |
@@ -299,6 +293,89 @@ struct StreamKKernel |
299 | 293 | {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size); |
300 | 294 | } |
301 | 295 |
|
| 296 | + CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs, |
| 297 | + index_t cta_idx) const |
| 298 | + { |
| 299 | + auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr); |
| 300 | + workgroup_barrier sk_flags(sk_flags_ptr); |
| 301 | + sk_flags.wait_set(0, 1, cta_idx); |
| 302 | + } |
| 303 | + |
| 304 | + CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const |
| 305 | + { |
| 306 | + auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr); |
| 307 | + workgroup_barrier sk_flags(sk_flags_ptr); |
| 308 | + sk_flags.wait_eq(1, cta_idx); |
| 309 | + } |
| 310 | + |
| 311 | + template <typename OAccTile> |
| 312 | + CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile, |
| 313 | + const OAccTile& in_block_tile) const |
| 314 | + { |
| 315 | + using BlockType = remove_cvref_t<decltype(in_out_block_tile)>; |
| 316 | + constexpr auto o_spans = BlockType::get_distributed_spans(); |
| 317 | + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { |
| 318 | + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { |
| 319 | + constexpr auto idx = make_tuple(idx0, idx1); |
| 320 | + in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx]; |
| 321 | + }); |
| 322 | + }); |
| 323 | + } |
| 324 | + |
| 325 | + template <typename DataType, typename OAccTileDist> |
| 326 | + CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs, |
| 327 | + index_t cta_idx, |
| 328 | + const OAccTileDist& c_block_tile_dist) const |
| 329 | + { |
| 330 | + const auto c_block_tile_buffer_size = |
| 331 | + TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType); |
| 332 | + void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) + |
| 333 | + kargs.tile_partitioner.get_flags_buffer_size() + |
| 334 | + cta_idx * c_block_tile_buffer_size; |
| 335 | + |
| 336 | + const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>( |
| 337 | + static_cast<DataType*>(partial_buffer_ptr), |
| 338 | + make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), |
| 339 | + make_tuple(TilePartitioner::NPerBlock, 1), |
| 340 | + number<GemmPipeline::GetVectorSizeC()>{}, |
| 341 | + number<1>{}); |
| 342 | + |
| 343 | + auto partial_tile_window = make_tile_window( |
| 344 | + partial_tensor_view, |
| 345 | + make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), |
| 346 | + {0, 0}, |
| 347 | + c_block_tile_dist); |
| 348 | + |
| 349 | + return load_tile(partial_tile_window); |
| 350 | + } |
| 351 | + |
| 352 | + template <typename OAccTile> |
| 353 | + CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs, |
| 354 | + index_t cta_idx, |
| 355 | + const OAccTile& c_block_tile) const |
| 356 | + { |
| 357 | + const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock * |
| 358 | + TilePartitioner::NPerBlock * |
| 359 | + sizeof(typename OAccTile::DataType); |
| 360 | + void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) + |
| 361 | + kargs.tile_partitioner.get_flags_buffer_size() + |
| 362 | + cta_idx * c_block_tile_buffer_size; |
| 363 | + |
| 364 | + const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>( |
| 365 | + static_cast<typename OAccTile::DataType*>(partial_buffer_ptr), |
| 366 | + make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), |
| 367 | + make_tuple(TilePartitioner::NPerBlock, 1), |
| 368 | + number<GemmPipeline::GetVectorSizeC()>{}, |
| 369 | + number<1>{}); |
| 370 | + |
| 371 | + auto partial_tile_window = make_tile_window( |
| 372 | + partial_tensor_view, |
| 373 | + make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}), |
| 374 | + {0, 0}); |
| 375 | + |
| 376 | + store_tile(partial_tile_window, c_block_tile); |
| 377 | + } |
| 378 | + |
302 | 379 | /// @brief Runs the main Stream-K algorithm. |
303 | 380 | /// @param kargs Stream-K kernel arguments. |
304 | 381 | /// @param cta_idx The current Stream-K workgroup's index. |
@@ -347,7 +424,85 @@ struct StreamKKernel |
347 | 424 | } |
348 | 425 | else |
349 | 426 | { |
350 | | - // TODO: Apply reduction logic. |
| 427 | + const auto c_macro_tile_idx = |
| 428 | + kargs.tile_partitioner.get_output_tile_index(tile_idx); |
| 429 | + index_t i_m = |
| 430 | + c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock; |
| 431 | + index_t i_n = |
| 432 | + c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock; |
| 433 | + |
| 434 | + const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a; |
| 435 | + const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b; |
| 436 | + CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr); |
| 437 | + |
| 438 | + // Create Gemm tensor views, pad views and tile windows |
| 439 | + const auto& gemm_tensor_views_tuple = |
| 440 | + UniversalGemmKernel::template MakeGemmTensorViews< |
| 441 | + EpiloguePipeline::MemoryOperation>( |
| 442 | + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size); |
| 443 | + |
| 444 | + const auto& gemm_pad_views = |
| 445 | + UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); |
| 446 | + auto gemm_tile_windows = |
| 447 | + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n); |
| 448 | + |
| 449 | + // Run GEMM cooperatively by whole workgroup. |
| 450 | + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); |
| 451 | + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); |
| 452 | + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); |
| 453 | + |
| 454 | +                    // Since num_loop can vary per WG and per iteration of the Stream-K while loop, |
| 455 | +                    // we compute has_hot_loop and tail_num here, following the same pattern as |
| 456 | +                    // grouped GEMM. We therefore call the GemmPipeline's operator() overload that |
| 457 | +                    // takes both has_hot_loop and tail_num. |
| 458 | + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk); |
| 459 | + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk); |
| 460 | + |
| 461 | + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], |
| 462 | + bs_block_window[UniversalGemmKernel::I0], |
| 463 | + num_loop_sk, |
| 464 | + has_hot_loop, |
| 465 | + tail_num, |
| 466 | + smem_ptr_0); |
| 467 | + |
| 468 | + auto tile_started = iter_start == tile_iter_start; |
| 469 | + auto tile_ended = iter_end >= tile_iter_end; |
| 470 | + if(!tile_started) |
| 471 | + { |
| 472 | + StorePartial(kargs, cta_idx, c_block_tile); |
| 473 | +                        __threadfence(); // make the partial store visible before signaling completion |
| 474 | + SignalStorePartialDone(kargs, cta_idx); |
| 475 | + } |
| 476 | + else |
| 477 | + { |
| 478 | + auto accum_block_tile = c_block_tile; |
| 479 | + if(!tile_ended) |
| 480 | + { |
| 481 | + const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile(); |
| 482 | + const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta(); |
| 483 | + const index_t extra_iters = kargs.tile_partitioner.get_extra_iters(); |
| 484 | +                        index_t accum_iters = local_iter_end - local_iter_start; |
| 485 | +                        index_t next_cta    = cta_idx + 1; |
| 486 | + |
| 487 | + while(accum_iters < iter_per_tile) |
| 488 | + { |
| 489 | + WaitStorePartialDone(kargs, next_cta); |
| 490 | + |
| 491 | + using BlockType = remove_cvref_t<decltype(c_block_tile)>; |
| 492 | + AddBlockTile( |
| 493 | + accum_block_tile, |
| 494 | + LoadPartial<typename BlockType::DataType>( |
| 495 | + kargs, next_cta, c_block_tile.get_tile_distribution())); |
| 496 | + |
| 497 | + accum_iters += iter_per_cta + (next_cta < extra_iters); |
| 498 | + ++next_cta; |
| 499 | + } |
| 500 | + } |
| 501 | + |
| 502 | + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); |
| 503 | + EpiloguePipeline{}( |
| 504 | + c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); |
| 505 | + } |
351 | 506 | } |
352 | 507 |
|
353 | 508 | // Prepare for next Stream-K loop iteration. |
|
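For readers skimming this change, the following stand-alone host-side sketch (plain C++ threads and atomics with made-up sizes; none of these names are CK Tile APIs) illustrates the hand-off the reduction path above implements: a CTA that does not start an output tile stores its partial accumulator and raises a flag, while the CTA that starts the tile waits on those flags, adds the partials, and runs the epilogue. This mirrors the roles of StorePartial, SignalStorePartialDone, WaitStorePartialDone, and AddBlockTile; in the real kernel the flags and per-CTA partial tiles live in the workspace sized by GetWorkSpaceSize.

```cpp
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    // Toy sizes; the real kernel derives these from TilePartitioner / kargs.
    constexpr int iters_per_tile = 8; // K-iterations needed to finish one C tile
    constexpr int iters_per_cta  = 3; // K-iterations assigned to each CTA
    constexpr int num_ctas       = 3; // enough CTAs to cover the tile

    std::vector<double> partials(num_ctas, 0.0);   // stand-in for the workspace partial buffers
    std::vector<std::atomic<int>> flags(num_ctas); // stand-in for the workspace flags
    for(auto& f : flags)
        f.store(0);

    auto cta = [&](int cta_idx) {
        const int iter_start = cta_idx * iters_per_cta;
        const int iter_end   = std::min(iter_start + iters_per_cta, iters_per_tile);

        // "Main loop": every K-iteration contributes 1.0 to the tile.
        double acc = 0.0;
        for(int i = iter_start; i < iter_end; ++i)
            acc += 1.0;

        const bool tile_started = (iter_start == 0);
        if(!tile_started)
        {
            partials[cta_idx] = acc;                            // StorePartial
            flags[cta_idx].store(1, std::memory_order_release); // SignalStorePartialDone
            return;
        }

        // The CTA that started the tile gathers partials until the tile is fully covered.
        int accum_iters = iter_end - iter_start;
        int next_cta    = cta_idx + 1;
        while(accum_iters < iters_per_tile)
        {
            while(flags[next_cta].load(std::memory_order_acquire) == 0) {} // WaitStorePartialDone
            acc += partials[next_cta];                                     // LoadPartial + AddBlockTile
            accum_iters += iters_per_cta; // nominal per-CTA count is enough to decide termination
            ++next_cta;
        }
        std::printf("tile value = %.1f (expected %d)\n", acc, iters_per_tile); // epilogue / store C
    };

    std::vector<std::thread> pool;
    for(int c = 0; c < num_ctas; ++c)
        pool.emplace_back(cta, c);
    for(auto& t : pool)
        t.join();
    return 0;
}
```

Note that the last CTA covers only two iterations here, yet the accumulation loop still terminates: a flag is raised only after a CTA has finished its whole assigned range, which is the same reason the kernel can advance accum_iters by the nominal per-CTA count (plus the extra-iteration adjustment) rather than the exact contribution.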