forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathsm90_tile_scheduler_stream_k.hpp
960 lines (827 loc) · 41.1 KB
/
sm90_tile_scheduler_stream_k.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/barrier.h"
#include "cutlass/block_striped.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
namespace cutlass::gemm::kernel::detail {
// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition
template <
class TileShape,
class ClusterShape
>
class PersistentTileSchedulerSm90StreamK {
//
// Data members
//
private:
using UnderlyingScheduler = PersistentTileSchedulerSm90;
private:
using UnderlyingArguments = typename UnderlyingScheduler::Arguments;
using UnderlyingParams = typename UnderlyingScheduler::Params;
uint64_t current_work_linear_idx_ = 0;
public:
using RasterOrder = UnderlyingScheduler::RasterOrder;
using RasterOrderOptions = UnderlyingScheduler::RasterOrderOptions;
static constexpr bool IsDynamicPersistent = false;
// Use a dummy barrier manager to simply get the type used to store the barrier
using BarrierType = typename NamedBarrierManager<1>::T;
using Params = PersistentTileSchedulerSm90StreamKParams;
using ReductionMode = Params::ReductionMode;
using DecompositionMode = Params::DecompositionMode;
struct WorkTileInfo {
int32_t M_idx = 0;
int32_t N_idx = 0;
int32_t K_idx = 0;
int32_t L_idx = 0;
// Number of k tiles to compute for this unit of work. For stream-K, this
// can indicate the number of K tiles across multiple output tiles.
uint32_t k_tile_count = 0;
// Number of k tiles remaining for the work unit as a whole
uint32_t k_tile_remaining = 0;
// Whether this unit of work is the final split for the given tile
bool is_separate_reduction = false;
CUTLASS_HOST_DEVICE
bool
is_valid() const {
// A work tile that computes no K tiles is invalid unless it is a separate-reduction work tile
// (which only performs reduction and epilogue)
return k_tile_count > 0 || is_separate_reduction;
}
CUTLASS_HOST_DEVICE
bool
is_reduction_unit() const {
return is_separate_reduction;
}
CUTLASS_HOST_DEVICE
int32_t
reduction_subtile_idx() const {
// For separate reduction units, the K_idx of the work tile is unused.
// Therefore, we override it to contain the subtile of that the reduction
// unit operates on.
return is_reduction_unit() ? K_idx : -1;
}
CUTLASS_HOST_DEVICE
void
setup_separate_reduction(int32_t epilogue_subtile_idx) {
// Set the epilogue subtile in the K_idx, since this is otherwise unused
// by separate reduction units.
K_idx = epilogue_subtile_idx;
is_separate_reduction = true;
k_tile_count = 0;
// Clean up remaining k tiles
k_tile_remaining = 0;
}
CUTLASS_HOST_DEVICE
static WorkTileInfo
invalid_work_tile() {
return {-1, -1, -1, -1, 0};
}
CUTLASS_HOST_DEVICE
bool
is_final_split(uint32_t k_tiles_per_output_tile) const {
return (K_idx + k_tile_count) == k_tiles_per_output_tile;
}
};
struct Arguments {
Arguments() = default;
Arguments(Arguments const&) = default;
Arguments(Arguments&&) = default;
CUTLASS_HOST_DEVICE
Arguments&
operator=(Arguments const& args) {
splits = args.splits;
max_swizzle_size = args.max_swizzle_size;
raster_order = args.raster_order;
reduction_mode = args.reduction_mode;
decomposition_mode = args.decomposition_mode;
return *this;
}
CUTLASS_HOST_DEVICE
Arguments&
operator=(Arguments&& args) noexcept {
splits = args.splits;
max_swizzle_size = args.max_swizzle_size;
raster_order = args.raster_order;
reduction_mode = args.reduction_mode;
decomposition_mode = args.decomposition_mode;
return *this;
}
CUTLASS_HOST_DEVICE
Arguments(int splits_) : splits(splits_) {}
CUTLASS_HOST_DEVICE
Arguments(int splits_, int max_swizzle_size_, RasterOrderOptions raster_order_, DecompositionMode decomposition_mode_) :
splits(splits_),
max_swizzle_size(max_swizzle_size_),
raster_order(raster_order_),
decomposition_mode(decomposition_mode_) {}
// The splitting factor to be used in a split-K decomposition of the problem.
// If this is set to a value greater than 1, stream-K decomposition logic
// is bypassed in favor of a split-K decomposition.
int splits = 1;
int max_swizzle_size = 1;
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
ReductionMode reduction_mode = ReductionMode::Deterministic;
DecompositionMode decomposition_mode = DecompositionMode::Heuristic;
};
// Sink scheduler params as a member
Params scheduler_params;
//
// Methods
//
template <class ProblemShape>
static Params
to_underlying_arguments(
ProblemShape problem_shape,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo const& hw_info,
Arguments const& args,
void* workspace,
const uint32_t epilogue_subtile = 1,
[[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) {
static_assert(cute::is_static<TileShape>::value);
static_assert(cute::is_static<ClusterShape>::value);
auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{})));
Params params;
params.initialize(
problem_blocks,
k_tile_per_output_tile,
to_gemm_coord(cluster_shape),
hw_info,
args.splits,
args.max_swizzle_size,
args.raster_order,
args.reduction_mode,
args.decomposition_mode,
workspace,
epilogue_subtile
);
return params;
}
static bool
can_implement(Arguments const& args) {
// Split count > 1 is only valid for heuristic and split-K decomposition modes
return (args.splits == 1 ||
args.decomposition_mode == DecompositionMode::Heuristic ||
args.decomposition_mode == DecompositionMode::SplitK);
}
CUTLASS_HOST_DEVICE
PersistentTileSchedulerSm90StreamK() { };
CUTLASS_HOST_DEVICE
PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_) {
if (params_.raster_order_ == RasterOrder::AlongN) {
current_work_linear_idx_ = uint64_t(BlockIdxX()) + uint64_t(BlockIdxY()) * uint64_t(GridDimX());
}
else {
current_work_linear_idx_ = uint64_t(BlockIdxX()) * uint64_t(GridDimY()) + uint64_t(BlockIdxY());
}
}
CUTLASS_DEVICE
WorkTileInfo
get_current_work() const {
return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params);
}
CUTLASS_DEVICE
static WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) {
// The maximum number of work units is units_per_problem_ * splits_.
// The multiplication by splits_ is used for handling split-K, in which
// units_per_problem_ is equal to the total number of output tiles. To account
// for the fact that we have splits_ peers per output tile, we multiply this
// value by splits_. For stream-K, this multiplication ends up being a no-op
// because splits_ is set to 1 for stream-K.
if(linear_idx >= (params.units_per_problem_ * params.divmod_splits_.divisor + params.separate_reduction_units_)) {
// Invalid work. Return an empty result.
return WorkTileInfo::invalid_work_tile();
}
WorkTileInfo work_tile_info;
assign_work(params, linear_idx, work_tile_info);
return work_tile_info;
}
// Returns whether the current work_tile_info passed in should continue to be used. This
// occurs only in the stream-K decomposition with stream-K work units, which encompass
// work over multiple output tiles. If the current work_tile_info should continue to be
// used, it is updated to advance to the next output tile it should cover.
CUTLASS_DEVICE
bool
continue_current_work(WorkTileInfo& work_tile_info) const {
return continue_current_work_for_linear_idx(
current_work_linear_idx_, work_tile_info, scheduler_params);
}
CUTLASS_DEVICE
static bool
continue_current_work_for_linear_idx(
uint64_t linear_idx,
WorkTileInfo& work_tile_info,
Params const& params) {
work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count;
if (work_tile_info.k_tile_remaining == 0) {
return false;
}
assign_work(params, linear_idx, work_tile_info);
return work_tile_info.is_valid();
}
CUTLASS_DEVICE
void
advance_to_next_work(uint32_t advance_count = 1) {
current_work_linear_idx_ += uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()) * uint64_t(advance_count);
}
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo work_tile_info, uint32_t advance_count = 1) const {
// Never pass this by reference; it needs a copy,
// because continue_current_work will modify it.
if (continue_current_work(work_tile_info)) {
return false;
}
return not get_current_work_for_linear_idx(
current_work_linear_idx_ + (
uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()) * uint64_t(advance_count)
),
scheduler_params
).is_valid();
}
// Given the inputs, computes the total number of output blocks this problem will compute over
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template <class ProblemShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(ProblemShape problem_shape_mnkl, TileShape cta_shape, ClusterShape cluster_shape) {
return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
}
// Given the cluster shape, computes the physical grid we should launch.
template <class ProblemShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
[[maybe_unused]] Params const& params,
ProblemShape problem_shape,
TileShape tile_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments) {
auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
hw_info,
arguments.max_swizzle_size,
arguments.raster_order
);
}
// Returns whether fixup is needed for `work_tile_info`.
CUTLASS_HOST_DEVICE
static bool
requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) {
// Fixup is not needed for invalid or data-parallel tiles
return work_tile_info.is_valid() && work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor;
}
CUTLASS_HOST_DEVICE
static bool
requires_separate_reduction(Params const& params) {
return params.requires_separate_reduction();
}
// When the work tile is not special for reduction, it's valid. Otherwise need to skip
// global loading that producer warpgroup do, also math computation that consumer warpgroup do.
CUTLASS_DEVICE
static bool
valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) {
return !work_tile_info.is_reduction_unit();
}
// Performs the reduction across splits for a given output tile.
template <class FrgTensorC>
CUTLASS_DEVICE
static void
fixup(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx) {
static constexpr uint32_t Offset = static_cast<int>(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0);
static constexpr uint32_t MaxNumNamedBarriers = 2;
using BarrierManager = NamedBarrierManager<NumThreadsPerWarpGroup, Offset, MaxNumNamedBarriers>;
return fixup_helper<FrgTensorC, BarrierManager>(
params, work_tile_info, accumulators, num_barriers, barrier_idx);
}
// Helper for performing the reduction across splits for a given output tile.
template <class FrgTensorC, class BarrierManager>
CUTLASS_DEVICE
static void
fixup_helper(
Params const& params,
WorkTileInfo const& work_tile_info,
FrgTensorC& accumulators,
uint32_t num_barriers,
uint32_t barrier_idx,
uint32_t num_accumulator_mtxs = 1) {
using ElementAccumulator = typename FrgTensorC::value_type;
if (!requires_fixup(params, work_tile_info)) {
return;
}
uint64_t tile_idx = output_tile_index(params, work_tile_info);
// Index of the lock on which to wait
uint64_t lock_idx = (tile_idx * num_barriers) + barrier_idx;
uint64_t reduction_tile_idx = tile_idx;
uint64_t num_peers = 0;
uint64_t reduction_peer_offset = 0;
if (params.requires_separate_reduction()) {
// If separate reduction is to be performed, each stream-K unit writes its partials
// to a separate portion of the workspace. There are as many of these portions as there
// are peers for a given output tile, so we multiply the tile index by the maximum peer count.
auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, static_cast<uint32_t>(work_tile_info.K_idx));
num_peers = last_peer_id - first_peer_id + 1;
reduction_tile_idx *= Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_);
reduction_peer_offset = my_peer_id * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{});
}
// Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood.
// Thus, the start of the reduction space is the same across all threads in a warp group.
uint64_t reduction_offset =
(static_cast<uint64_t>(cute::size<0>(TileShape{})) * static_cast<uint64_t>(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) +
reduction_peer_offset +
(static_cast<uint64_t>(size(accumulators)) * barrier_idx * BarrierManager::ThreadCount);
ElementAccumulator* group_reduction_workspace = reinterpret_cast<ElementAccumulator*>(params.reduction_workspace_) + reduction_offset;
using AccumulatorArrayT = Array<typename FrgTensorC::value_type, size(FrgTensorC{})>;
using BlockStripedReduceT = BlockStripedReduce<BarrierManager::ThreadCount, AccumulatorArrayT>;
AccumulatorArrayT* reduction_workspace_array = reinterpret_cast<AccumulatorArrayT*>(group_reduction_workspace);
AccumulatorArrayT* accumulator_array = reinterpret_cast<AccumulatorArrayT*>(accumulators.data());
uint32_t barrier_group_thread_idx = ThreadIdxX() % BarrierManager::ThreadCount;
// The number of tiles for which reduction is required is either:
// (a) the total number of output tiles (in the case of split-K)
// (b) the number of stream-K tiles (potentially multiplied by peer count if using separate reduction)
// To calculate the total number of output tiles in the split-K case, we
// note that, in the split-K case, the units_per_problem_ member of Params will be
// the total number of output tiles.
uint32_t reduction_tiles = 0;
if (params.divmod_splits_.divisor > 1) {
reduction_tiles = params.units_per_problem_;
}
else if (params.requires_separate_reduction()) {
reduction_tiles = params.sk_tiles_ * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_);
}
else {
reduction_tiles = params.sk_tiles_;
}
uint64_t reduction_workspace_size = Params::get_reduction_workspace_size(
reduction_tiles, to_gemm_coord(TileShape{}), sizeof_bits<ElementAccumulator>::value, num_accumulator_mtxs);
BarrierType* lock_workspace = reinterpret_cast<BarrierType*>(
reinterpret_cast<uint8_t*>(params.reduction_workspace_) + reduction_workspace_size);
if (work_tile_info.is_reduction_unit()) {
plus<AccumulatorArrayT> add_fragments;
uint64_t peer_offset = size(accumulators) * num_barriers * BarrierManager::ThreadCount;
// Wait until the peers collaborating on this output tile have all written
// their accumulators to workspace.
BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, num_peers);
// Load the first peer's data
BlockStripedReduceT::load(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx);
for (uint64_t i = 1; i < num_peers; ++i) {
// Load peer fragment
AccumulatorArrayT addend_fragment;
auto peer_reduction_workspace = reinterpret_cast<AccumulatorArrayT*>(group_reduction_workspace + (i * peer_offset));
BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, barrier_group_thread_idx);
// Add peer fragment
*accumulator_array = add_fragments(*accumulator_array, addend_fragment);
}
}
else if (!compute_epilogue(work_tile_info, params)) {
if (params.requires_separate_reduction() || work_tile_info.K_idx == 0) {
// The first peer initializes the workspace partials in the non-separate-reduction case,
// and all peers write to their own location in workspace when using separate reduction
BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
}
else {
// Wait until the preceding split added its accumulators
BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
// Perform reduction in workspace
BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
}
// If separate reduction is being performed, each participating stream-K unit increments the barrier
// by only 1. Otherwise, increment by the K tile count that this unit has processed.
uint32_t increment = params.requires_separate_reduction() ? 1 : work_tile_info.k_tile_count;
// Signal our arrival
BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment);
}
else {
if (params.reduction_mode_ == ReductionMode::Deterministic) {
// Wait until the preceding split added its accumulators
BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
}
else {
// Wait unitl the first split has stored its accumulators
BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1);
}
// The block computing the final split for the tile adds previously-reduced partials
// to its accumulators and computes the epilogue.
BlockStripedReduceT::load_add(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx);
}
}
// Returns whether the block assigned this work should compute the epilogue for the corresponding
// output tile. For the case of stream-K, this should only occur if the work is marked as the final split.
CUTLASS_HOST_DEVICE
static bool
compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) {
// `is_final_split` will be set to `true` for the following scenarios, all of which must compute the epilogue:
// 1. The tile is computed in data-parallel mode
// 2. The tile is computed in split-/stream-K mode and this work unit represents the final split of the tile
// 3. The tile is computed in split-/stream-K mode and separate reduction is used, and this is a separate reduction unit
return work_tile_info.is_valid() &&
(work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor) &&
!params.requires_separate_reduction()) || work_tile_info.is_separate_reduction;
}
// Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K]
CUTLASS_DEVICE
static uint64_t
output_tile_index(Params const& params, WorkTileInfo const& work_tile_info) {
uint64_t linear_idx_in_batch = UnderlyingScheduler::get_linear_idx_from_m_and_n(
work_tile_info.M_idx, work_tile_info.N_idx,
params.divmod_cluster_shape_major_,
params.divmod_cluster_shape_minor_,
params.divmod_cluster_blk_major_,
params.log_swizzle_size_,
params.raster_order_
);
uint64_t tiles_mn = params.divmod_batch_.divisor;
return tiles_mn * work_tile_info.L_idx + linear_idx_in_batch;
}
template <class ProblemShape, class ElementAccumulator>
static size_t
get_workspace_size(
Arguments const& args,
ProblemShape problem_shape,
KernelHardwareInfo const& hw_info,
uint32_t mma_warp_groups,
const uint32_t epilogue_subtile = 1,
[[maybe_unused]] uint32_t num_accumulator_mtxs = 1) {
auto problem_shape_mnkl = cute::append<4>(problem_shape, 1);
ClusterShape cluster_shape;
TileShape tile_shape;
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{})));
return Params::get_workspace_size(
problem_blocks,
k_tile_per_output_tile,
to_gemm_coord(tile_shape),
to_gemm_coord(cluster_shape),
hw_info,
args.splits,
args.max_swizzle_size,
args.raster_order,
args.decomposition_mode,
mma_warp_groups,
sizeof_bits<BarrierType>::value,
sizeof_bits<ElementAccumulator>::value,
epilogue_subtile
);
}
template <class ProblemShape, class ElementAccumulator>
static cutlass::Status
initialize_workspace(
Arguments const& args,
void* workspace,
cudaStream_t stream,
ProblemShape const& problem_shape,
KernelHardwareInfo const& hw_info,
uint32_t mma_warp_groups,
const uint32_t epilogue_subtile = 1,
[[maybe_unused]] uint32_t num_accumulator_mtxs = 1,
CudaHostAdapter* cuda_adapter = nullptr) {
auto problem_shape_mnkl = cute::append<4>(problem_shape, 1);
ClusterShape cluster_shape;
TileShape tile_shape;
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape);
uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{})));
return Params::initialize_workspace(
workspace,
stream,
problem_blocks,
k_tile_per_output_tile,
to_gemm_coord(tile_shape),
to_gemm_coord(cluster_shape),
hw_info,
args.splits,
args.max_swizzle_size,
args.raster_order,
args.decomposition_mode,
mma_warp_groups,
sizeof_bits<BarrierType>::value,
sizeof_bits<ElementAccumulator>::value,
epilogue_subtile,
1,
cuda_adapter
);
}
template <class ProblemShape>
CUTLASS_HOST_DEVICE
static uint32_t
get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape, TileShape) {
return work_tile_info.k_tile_count;
}
CUTLASS_HOST_DEVICE
static uint32_t
get_work_k_tile_start(WorkTileInfo const& work_tile_info) {
return work_tile_info.K_idx;
}
// Kernel helper function to get next work tile
CUTLASS_DEVICE
auto
fetch_next_work(WorkTileInfo work_tile_info) {
if (continue_current_work(work_tile_info)) {
return cute::make_tuple(work_tile_info, true);
}
advance_to_next_work();
return cute::make_tuple(get_current_work(), true);
}
// Returns the initial work tile info that will be computed over
CUTLASS_DEVICE
WorkTileInfo
initial_work_tile_info(ClusterShape) {
return get_current_work();
}
private:
// Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info
// is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining
// iterations) is used to find the next tile in the current work unit.
CUTLASS_DEVICE
static void
assign_work(
Params const& params,
uint64_t linear_idx,
WorkTileInfo& work_tile_info) {
auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = cute::block_id_in_cluster();
uint64_t cta_m_in_cluster = static_cast<uint64_t>(cta_m_in_cluster_);
uint64_t cta_n_in_cluster = static_cast<uint64_t>(cta_n_in_cluster_);
uint64_t output_tile_id = linear_idx;
if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) {
// Separate-reduction work
auto cluster_size = params.get_cluster_size();
// Divide up the linearized separate reduction units into clusters
uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_));
uint64_t cluster_tile_idx, epi_subtile_idx;
params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx);
// Bring the linearized tile ID back into the space of tiles, rather than clusters
output_tile_id = cluster_tile_idx * cluster_size;
work_tile_info.setup_separate_reduction(epi_subtile_idx);
}
else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) {
// Data-parallel work
output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_;
work_tile_info.K_idx = 0;
work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor;
work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor;
}
else {
// In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K
// threadblock individually. For the most part, the set of K iterations corresponding to stream-K
// work was divided amongst stream-K threadblocks, and a threadblock determined which tile
// it would compute a (potentially-partial) output tile for based on the space of k iterations
// assigned to it. This often results in stream-K threadblocks processing tiles with different
// offsets in the K dimension from one another. This can reduce locality, but is lmitied to the
// (generally few) waves of threadblocks assigned to compute stream-K work.
//
// With the introduction of threadblock clusters, there is additional benefit to maintaining
// locality in the K dimension: shared portions of operands can be multicasted to threadblocks
// within a cluster. Thus, we would like to ensure that the assignment of stream-K work to
// threadblocks respects the ability to perform multicasting.
//
// To do so, we divide up the linearized stream-K units into clusters and share the same K
// offsets for work within clusters.
uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx);
uint64_t group_idx;
params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx);
// Determine whether we are in a "big group" that will process an additional
// stream-K cluster tile.
uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_);
uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles);
if (group_idx < params.big_groups_) {
++sk_cluster_tiles_in_group;
}
// Determine whether we are in a "big unit" within the group, that will process
// an additional K chunk in the group.
uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size();
uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor;
uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group);
uint64_t big_units_in_group = params.div_cluster_size(
k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor));
uint64_t split;
params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx);
bool is_split_k = params.divmod_splits_.divisor > 1;
uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx;
uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group;
uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group;
uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group;
// Determine the starting k iteration computed by this stream-K work unit
uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) +
(k_tiles_per_split * split);
// Adjust the starting position and number of k iterations for "big units," which
// compute one extra iteration. If there are any big units, they will be the first
// in the linearized ID space.
auto k_tiles_in_my_split = k_tiles_per_split;
if (big_unit_cmp_lhs < big_unit_cmp_rhs) {
// Since the "big units" are the first units in the linearized ID space, each
// of the units preceding this big unit computed one extra iteration. Thus,
// we must offset our start iteration by the number of units that precede
// the current unit in the linearized ID space.
unit_iter_start += big_unit_cmp_lhs;
++k_tiles_in_my_split;
}
else {
// Increment by one for each of the big clusters (since all big units precede this unit)
unit_iter_start += big_unit_cmp_rhs;
}
if (!is_split_k) {
// Adjust the unit starting position and number of tiles to avoid
// computing splits of size less than min_iters_per_sk_unit_
int unused, start_tile_k_tile;
params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start);
if (start_tile_k_tile < Params::min_iters_per_sk_unit_) {
// Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another
// stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles.
// Adjust our work to take over these K tiles.
unit_iter_start -= start_tile_k_tile;
k_tiles_in_my_split += start_tile_k_tile;
}
else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) {
// Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile,
// which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles.
// Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles.
auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile);
unit_iter_start += adjustment_tiles;
k_tiles_in_my_split -= adjustment_tiles;
}
else if (params.ktile_start_alignment_count == 2 && start_tile_k_tile % 2 != 0) {
// ktile for each SM start from even number
// If start from odd number ktile within the output tile
// now start at the ktile one before my initial ktile start (take one ktile from prev sm)
// if end on odd number ktile within the output tile
// now end at ktile that one before my ktile end (give one ktile to next sm)
unit_iter_start -= 1;
k_tiles_in_my_split += 1;
}
}
if (work_tile_info.k_tile_count == 0) {
// This is a new unit
if (!is_split_k) {
//
// Adjust the unit ending position and number of tiles to avoid
// computing splits of size less than min_iters_per_sk_unit_
//
// Begin by assuming that no adjustment is needed
auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split;
int unused, end_tile_k_tile;
params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end);
if (end_tile_k_tile < Params::min_iters_per_sk_unit_) {
// Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile,
// which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles.
// Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles.
k_tiles_in_my_split -= end_tile_k_tile;
}
else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) {
// Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile,
// which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles.
// Adjust our work to take on these K tiles.
k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile);
}
else if (params.ktile_start_alignment_count == 2 && end_tile_k_tile % 2 != 0) {
// ktile for each SM start from even number
// If start from odd number ktile within the output tile
// now start at the ktile one before my initial ktile start (take one ktile from prev sm)
// If end on odd number ktile within the output tile,
// now end at ktile that one before my ktile end (give one ktile to next sm)
k_tiles_in_my_split -= 1;
}
}
work_tile_info.k_tile_remaining = k_tiles_in_my_split;
}
uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1;
// Find the output tile corresponding to the final k tile covered by this
// work unit. Stream-K work units will work backwards in terms of the tiles they
// are responsible computing. This is beneficial because the final (partial)
// tile computed by a stream-K block is typically the beginning of the output
// tile, while the beginning (partial) tile is typically the ending of another
// output tile. Since ending portions of an output tile must reduce across
// other work units computing portions of that output tile, it is preferable
// for them to be computed later, so as to reduce the likelihood of blocking
// on other work.
auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end);
uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor;
uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor;
// Convert the output tile from the linearized space within each group to the
// overall linearized space.
output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx;
// Bring the linearized tile ID back into the space of tiles, rather than clusters
output_tile_id *= params.get_cluster_size();
// The final linearized tile ID is in units of the cluster dimension over which we rasterize.
if (params.raster_order_ == RasterOrder::AlongN) {
output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor;
}
else {
output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor;
}
// The unit's starting k iteration in the current tile is either the starting
// iteration for the tile as a whole, or the starting k iteration for the unit
// as a whole (if the latter is greater than the former).
uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start);
// Similarly, the unit's ending k iteration (exclusive) is either the end of
// the current tile it is assigned, or the ending iteration of the unit as a whole
// (if the latter is less than the former).
uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1);
// Set the k offset to be the starting k tile for this output tile
work_tile_info.K_idx = static_cast<int32_t>(tile_iter_start - output_tile_iter_start);
work_tile_info.k_tile_count = tile_iter_end - tile_iter_start;
}
uint64_t work_idx_l, remainder;
params.divmod_batch_(work_idx_l, remainder, output_tile_id);
uint64_t cta_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder);
auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n(
cta_per_grid_dim,
params.divmod_cluster_shape_major_,
params.divmod_cluster_shape_minor_,
params.divmod_cluster_blk_major_,
params.log_swizzle_size_,
params.raster_order_
);
// Set the M, N, and L block offsets
work_tile_info.M_idx = work_idx_m;
work_tile_info.N_idx = work_idx_n;
work_tile_info.L_idx = static_cast<int32_t>(work_idx_l);
}
// Returns a tuple of stream-K unit (peer) IDs for output tile `tile_idx`:
//   (unit computing the tile's first K tile,
//    unit computing K tile `cur_k_tile`,
//    unit computing the tile's last K tile).
//
// Units are laid out along the linearized K-tile space with the `big_units_` "big units"
// (one extra K tile each) first, followed by the normal units. After locating the unit that
// nominally owns a K tile, the result is corrected to mirror the K-range adjustments made
// by the stream-K work assignment in this scheduler:
//  - a unit whose portion of an output tile would be shorter than min_iters_per_sk_unit_
//    K tiles hands that portion to the neighboring unit, and
//  - when ktile_start_alignment_count == 2, a unit boundary that falls on an odd K tile
//    at least min_iters_per_sk_unit_ K tiles away from both ends of the output tile is
//    moved one K tile earlier, so that K tile is computed by the following unit.
//
// NOTE(review): `cur_k_tile` is compared against start_k_tile/end_k_tile, which are in the
// linearized K-tile space (tiles_per_output_tile * tile index); confirm callers pass a
// K tile in that same space. Stream-K groups are not considered here -- presumably this is
// only exact for a single group; verify.
CUTLASS_HOST_DEVICE
static auto
tile_peer_range(Params const& params, uint32_t tile_idx, uint32_t cur_k_tile) {
  uint32_t tile_idx_in_cluster_path = params.div_cluster_size(tile_idx);
  uint32_t k_tiles_per_output_tile = params.divmod_tiles_per_output_tile_.divisor;
  uint32_t start_k_tile = k_tiles_per_output_tile * tile_idx_in_cluster_path;
  uint32_t end_k_tile = start_k_tile + k_tiles_per_output_tile - 1;
  // Total number of K tiles covered by the big units, which precede all normal units.
  uint32_t big_unit_k_tiles = params.big_units_ * (params.divmod_k_tiles_per_sk_unit_.divisor + 1);

  // Given that unit `unit_idx` nominally covers K tiles
  // [unit_k_start, unit_k_start + k_tiles_per_unit) and contains `k_tile`, returns the
  // ID of the unit that actually computes `k_tile` once boundary adjustments are applied.
  auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t unit_k_start, uint32_t k_tiles_per_unit) {
    uint32_t unit_k_end = unit_k_start + k_tiles_per_unit;
    if (k_tile - start_k_tile < Params::min_iters_per_sk_unit_ &&
        unit_k_end - start_k_tile < Params::min_iters_per_sk_unit_) {
      // k_tile is within the first min_iters_per_sk_unit_ K tiles of this output tile,
      // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this
      // output tile. This work will thus be subsumed by the next stream-K unit.
      ++unit_idx;
    }
    if (end_k_tile + 1 - k_tile < Params::min_iters_per_sk_unit_ &&
        end_k_tile + 1 - unit_k_start < Params::min_iters_per_sk_unit_) {
      // k_tile is within the last min_iters_per_sk_unit_ K tiles of this output tile,
      // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this
      // output tile. This work will thus be subsumed by the previous stream-K unit.
      --unit_idx;
    }
    if (params.ktile_start_alignment_count == 2 && k_tile + 1 == unit_k_end) {
      // Mirror the even-start alignment of the work assignment: a unit boundary at an odd
      // position inside the output tile (outside the min_iters_per_sk_unit_ margins) is
      // moved one K tile earlier, so this unit's last K tile goes to the next unit.
      uint32_t end_in_tile = unit_k_end - start_k_tile;
      if (end_in_tile < k_tiles_per_output_tile &&
          end_in_tile % 2 != 0 &&
          end_in_tile >= Params::min_iters_per_sk_unit_ &&
          end_in_tile <= k_tiles_per_output_tile - Params::min_iters_per_sk_unit_) {
        ++unit_idx;
      }
    }
    return unit_idx;
  };

  // Lambda to find the ID of the stream-K unit that computes this K tile
  auto find_unit = [&](uint32_t k_tile) {
    if (k_tile < big_unit_k_tiles) {
      // The tile is within the "big unit range"
      uint32_t unit_idx = params.divmod_k_tiles_per_sk_big_unit_.divide(k_tile);
      uint32_t k_tiles_per_big_unit = params.divmod_k_tiles_per_sk_big_unit_.divisor;
      return static_cast<uint64_t>(
        adjust_unit(k_tile, unit_idx, unit_idx * k_tiles_per_big_unit, k_tiles_per_big_unit));
    }
    else {
      // The tile is after the "big unit range." Find the "normal unit" it belongs to, then
      // offset the ID by the number of big units.
      uint32_t normal_unit_idx = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles);
      uint32_t k_tiles_per_unit = params.divmod_k_tiles_per_sk_unit_.divisor;
      uint32_t unit_idx = normal_unit_idx + params.big_units_;
      // Normal units start after every big-unit K tile; computing this as
      // unit_idx * k_tiles_per_unit would miss the extra K tile of each big unit.
      uint32_t unit_k_start = big_unit_k_tiles + normal_unit_idx * k_tiles_per_unit;
      return static_cast<uint64_t>(
        adjust_unit(k_tile, unit_idx, unit_k_start, k_tiles_per_unit));
    }
  };

  return cute::make_tuple(find_unit(start_k_tile), find_unit(cur_k_tile), find_unit(end_k_tile));
}
};
} // namespace cutlass::gemm::kernel::detail