forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathell_gemm.h
824 lines (678 loc) · 28.2 KB
/
ell_gemm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a Block-Ell sparse gemm kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/arch/arch.h"
#include "cutlass/transform/threadblock/ell_iterator.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Threadblock-level kernel for a GEMM in which one operand is stored in
/// Blocked-ELL sparse format.
///
/// This primary template is used when `IsASparse` is true (A is the sparse
/// operand); the `IsASparse == false` case is a separate partial specialization.
/// Each CTA runs the threadblock-scoped `Mma` mainloop, which is handed an ELL
/// index iterator, and then the `Epilogue`, which reads source tensor C and
/// writes destination tensor D. When `SplitKSerial` is true and the grid has
/// more than one K partition, CTAs that share an output tile serialize their
/// epilogues through a per-tile semaphore.
template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_,   ///! Threadblock swizzling function
  bool SplitKSerial,              ///! If true, code supporting split-K via serial reduction is enabled.
  bool IsASparse                  ///! If true, A is sparse matrix
>
struct EllGemm {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static bool const kSplitKSerial = SplitKSerial;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size{};
    cutlass::gemm::GemmCoord grid_tiled_shape{};
    int swizzle_log_tile{0};
    typename Mma::IteratorA::Params params_A{};
    typename Mma::IteratorA::TensorRef ref_A{};
    typename Mma::IteratorB::Params params_B{};
    typename Mma::IteratorB::TensorRef ref_B{};
    typename Epilogue::OutputTileIterator::Params params_C{};
    typename Epilogue::OutputTileIterator::TensorRef ref_C{};
    typename Epilogue::OutputTileIterator::Params params_D{};
    typename Epilogue::OutputTileIterator::TensorRef ref_D{};
    typename OutputOp::Params output_op{};
    int *semaphore = nullptr;       ///< Split-K serial semaphores, indexed per (m, n) output tile
    int gemm_k_iterations{0};       ///< Mainloop iterations per K partition (rounded up)
    int gemm_k_size{0};             ///< K extent of one K partition (gemm_k_iterations * Mma::Shape::kK)
    const int* ell_idx = nullptr;   ///< Blocked-ELL block-column index array of the sparse operand A
    int ell_ncol{0};                ///< Column count of A's ELL storage; limits K, and 0 skips the mainloop
    int ell_blocksize{0};           ///< ELL block size; one ELL block row spans this many rows of A
    int ell_base_idx{0};            ///< Index base forwarded to the ELL iterator (presumably 0 or 1 — confirm)

    //
    // Methods
    //

    Params() = default;

    /// Builds kernel parameters. K is split into grid_tiled_shape.k() partitions
    /// of gemm_k_size elements each (the last partition may be shorter).
    /// `workspace` becomes the split-K serial semaphore array.
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      typename Mma::IteratorA::TensorRef ref_A,
      typename Mma::IteratorB::TensorRef ref_B,
      typename Epilogue::OutputTileIterator::TensorRef ref_C,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      const int* ell_idx,
      int ell_ncol,
      int ell_blocksize,
      int ell_base_idx,
      typename OutputOp::Params output_op = typename OutputOp::Params(),
      int *workspace = nullptr
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
      params_A(ref_A.layout()),
      ref_A(ref_A),
      params_B(ref_B.layout()),
      ref_B(ref_B),
      params_C(ref_C.layout()),
      ref_C(ref_C),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      ell_idx(ell_idx),
      ell_ncol(ell_ncol),
      ell_blocksize(ell_blocksize),
      ell_base_idx(ell_base_idx)
    {
      int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Assign the member (not a shadowing local) so the field reflects the split.
      gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

      semaphore = workspace;
    }
  };

  /// Shared memory storage structure
  struct SharedStorage {
    // Mainloop and epilogue run one after the other, so they share storage.
    union {
      typename Mma::SharedStorage main_loop;
      typename Epilogue::SharedStorage epilogue;
    };
    typename cutlass::transform::threadblock::ell::SharedStorage ell;
  };

  //
  // Methods
  //

  EllGemm() = default;

  /// Determines whether kernel satisfies alignment
  ///
  /// @return Status::kErrorMisalignedOperand if any tensor pointer/stride or
  ///         problem extent violates the iterators' access width, else kSuccess.
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size,
    typename Mma::IteratorA::TensorRef ref_A,
    typename Mma::IteratorB::TensorRef ref_B,
    typename Epilogue::OutputTileIterator::TensorRef ref_C,
    typename Epilogue::OutputTileIterator::TensorRef ref_D) {

    static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
                                                      layout::ColumnMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorA::Layout,
                                                        layout::ColumnMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
                                                      layout::RowMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorB::Layout,
                                                        layout::RowMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if (!TensorRef_aligned(ref_A, kAlignmentA)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_B, kAlignmentB)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_C, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_D, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
      (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
      (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM tile for the threadblock selected by the swizzle.
  ///
  /// @param params          Kernel parameters built by Params(...)
  /// @param shared_storage  Shared memory for the mainloop, epilogue and ELL index iterator
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // One ELL block row (ell_blocksize rows of A) is covered by
    // tile_in_ell_block consecutive M tiles; split this CTA's M tile index into
    // (ELL block row, tile within that block row).
    int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1 ) / Mma::Shape::kM;
    int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
    int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Skip the mainloop when the sparse operand has no ELL columns; the
    // epilogue below still runs with zeroed accumulators.
    if (params.ell_ncol > 0) {
      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        ell_block_offset_m * params.ell_blocksize
        + tile_offset_m * Mma::Shape::kM,
        threadblock_tile_offset.k() * params.gemm_k_size
      };

      cutlass::MatrixCoord tb_offset_B{
        threadblock_tile_offset.k() * params.gemm_k_size,
        threadblock_tile_offset.n() * Mma::Shape::kN
      };

      // Start of this ELL block row inside ell_idx: the index array holds
      // ell_ncol / ell_blocksize entries per ELL block row.
      int ell_idx_start =
        ell_block_offset_m *
        (params.ell_ncol / params.ell_blocksize);
      const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);

      // Problem size is a function of threadblock index in the K dimension;
      // it is further clamped to the ELL column count.
      int problem_size_k = min(
        params.problem_size.k(),
        (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
      problem_size_k = min(problem_size_k, params.ell_ncol);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations =
        (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        params.ref_A.data(),
        {params.problem_size.m(), problem_size_k},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        params.ref_B.data(),
        {problem_size_k, params.problem_size.n()},
        thread_idx,
        tb_offset_B);

      // Stride of the dense operand's (B's) iterator, handed to the ELL
      // iterator — presumably to turn ELL block indices into B offsets.
      int ell_stride = iterator_B.get_stride();

      typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
        shared_storage.ell,
        ell_idx_ptr,
        params.ell_blocksize,
        params.ell_base_idx,
        Mma::Shape::kK,
        problem_size_k,
        ell_stride,
        thread_idx
      );

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      if (!kSplitKSerial || gemm_k_iterations > 0) {
        // check if index computations can be skipped
        constexpr int kAlignmentA = Mma::IteratorA::AccessType::kElements;
        constexpr int kAlignmentB = Mma::IteratorB::AccessType::kElements;
        constexpr int kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
        constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8);
        constexpr bool is_multiple_alignment =
          (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);

        // ell_blocksize is a power of two and spans at least one K tile.
        const bool is_specialized_blocksize =
          ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
          && params.ell_blocksize >= Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add. The first template
        // argument is `true` in this A-sparse kernel (the B-sparse kernel passes
        // `false`); the second selects the specialized path checked above.
        // `template` is required: `mma` has a dependent type.
        if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
          mma.template operator()<true, true>(
            gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
        else {
          mma.template operator()<true, false>(
            gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
      }
    } // if (params.ell_ncol > 0)

    //
    // Epilogue
    //

    OutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    // Recomputed here instead of being kept live across the mainloop
    // (presumably to relieve register pressure).
    threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
    ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block;
    tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block;

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      ell_block_offset_m * params.ell_blocksize
      + tile_offset_m * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    // Avoid out-of-bounds accesses: clamp rows to both the problem extent and
    // the end of the current ELL block row.
    MatrixCoord threadblock_extent(
      min(params.problem_size.m(),
          ell_block_offset_m * params.ell_blocksize
          + min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)),
      min(params.problem_size.n(),
          (threadblock_tile_offset.n()+1) * Mma::Shape::kN)
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // Fetch the synchronization lock initially but do not block.
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      params.ref_C.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      params.ref_D.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op, iterator_D, accumulators, iterator_C);

    //
    // Release the semaphore
    //

    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // The final K partition resets the semaphore to 0 for subsequent grids;
      // every other partition hands the lock to the next partition.
      int lock = (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
                   ? 0
                   : threadblock_tile_offset.k() + 1;

      semaphore.release(lock);
    }
  }
};
/// Partial specialization for `IsASparse == false`: B is the sparse operand
/// stored in Blocked-ELL format, and each CTA's N tile is mapped onto an ELL
/// block column. Otherwise it mirrors the A-sparse kernel: an ELL-driven `Mma`
/// mainloop, followed by an `Epilogue` reading C and writing D, with optional
/// split-K serial reduction through a per-tile semaphore.
template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_,   ///! Threadblock swizzling function
  bool SplitKSerial               ///! If true, code supporting split-K via serial reduction is enabled.
>
struct EllGemm<Mma_, Epilogue_, ThreadblockSwizzle_, SplitKSerial, false> {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static bool const kSplitKSerial = SplitKSerial;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size{};
    cutlass::gemm::GemmCoord grid_tiled_shape{};
    int swizzle_log_tile{0};
    typename Mma::IteratorA::Params params_A{};
    typename Mma::IteratorA::TensorRef ref_A{};
    typename Mma::IteratorB::Params params_B{};
    typename Mma::IteratorB::TensorRef ref_B{};
    typename Epilogue::OutputTileIterator::Params params_C{};
    typename Epilogue::OutputTileIterator::TensorRef ref_C{};
    typename Epilogue::OutputTileIterator::Params params_D{};
    typename Epilogue::OutputTileIterator::TensorRef ref_D{};
    typename OutputOp::Params output_op{};
    int *semaphore = nullptr;       ///< Split-K serial semaphores, indexed per (m, n) output tile
    int gemm_k_iterations{0};       ///< Mainloop iterations per K partition (rounded up)
    int gemm_k_size{0};             ///< K extent of one K partition (gemm_k_iterations * Mma::Shape::kK)
    const int* ell_idx = nullptr;   ///< Blocked-ELL block index array of the sparse operand B
    int ell_ncol{0};                ///< ELL column count; limits K, and 0 skips the mainloop
    int ell_blocksize{0};           ///< ELL block size; one ELL block spans this many columns of B
    int ell_base_idx{0};            ///< Index base forwarded to the ELL iterator (presumably 0 or 1 — confirm)

    //
    // Methods
    //

    Params() = default;

    /// Builds kernel parameters. K is split into grid_tiled_shape.k() partitions
    /// of gemm_k_size elements each (the last partition may be shorter).
    /// `workspace` becomes the split-K serial semaphore array.
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      typename Mma::IteratorA::TensorRef ref_A,
      typename Mma::IteratorB::TensorRef ref_B,
      typename Epilogue::OutputTileIterator::TensorRef ref_C,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      const int* ell_idx,
      int ell_ncol,
      int ell_blocksize,
      int ell_base_idx,
      typename OutputOp::Params output_op = typename OutputOp::Params(),
      int *workspace = nullptr
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
      params_A(ref_A.layout()),
      ref_A(ref_A),
      params_B(ref_B.layout()),
      ref_B(ref_B),
      params_C(ref_C.layout()),
      ref_C(ref_C),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      ell_idx(ell_idx),
      ell_ncol(ell_ncol),
      ell_blocksize(ell_blocksize),
      ell_base_idx(ell_base_idx)
    {
      int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Assign the member (not a shadowing local) so the field reflects the split.
      gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

      semaphore = workspace;
    }
  };

  /// Shared memory storage structure
  struct SharedStorage {
    // Mainloop and epilogue run one after the other, so they share storage.
    union {
      typename Mma::SharedStorage main_loop;
      typename Epilogue::SharedStorage epilogue;
    };
    typename cutlass::transform::threadblock::ell::SharedStorage ell;
  };

  //
  // Methods
  //

  EllGemm() = default;

  /// Determines whether kernel satisfies alignment
  ///
  /// @return Status::kErrorMisalignedOperand if any tensor pointer/stride or
  ///         problem extent violates the iterators' access width, else kSuccess.
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size,
    typename Mma::IteratorA::TensorRef ref_A,
    typename Mma::IteratorB::TensorRef ref_B,
    typename Epilogue::OutputTileIterator::TensorRef ref_C,
    typename Epilogue::OutputTileIterator::TensorRef ref_D) {

    static int const kAlignmentA = (platform::is_same<typename Mma::IteratorA::Layout,
                                                      layout::ColumnMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorA::Layout,
                                                        layout::ColumnMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = (platform::is_same<typename Mma::IteratorB::Layout,
                                                      layout::RowMajorInterleaved<32>>::value)
                                   ? 32
                                   : (platform::is_same<typename Mma::IteratorB::Layout,
                                                        layout::RowMajorInterleaved<64>>::value)
                                     ? 64
                                     : Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    if (!TensorRef_aligned(ref_A, kAlignmentA)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_B, kAlignmentB)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_C, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(ref_D, kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
      (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
      (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }

  /// Executes one GEMM tile for the threadblock selected by the swizzle.
  ///
  /// @param params          Kernel parameters built by Params(...)
  /// @param shared_storage  Shared memory for the mainloop, epilogue and ELL index iterator
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // One ELL block column (ell_blocksize columns of B) is covered by
    // tile_in_ell_block consecutive N tiles; split this CTA's N tile index into
    // (ELL block column, tile within that block column).
    int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1 ) / Mma::Shape::kN;
    int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
    int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Broadcast the warp_id computed by lane 0 to ensure dependent code
    // is compiled as warp-uniform.
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    int lane_idx = threadIdx.x % 32;

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    // Skip the mainloop when the sparse operand has no ELL columns; the
    // epilogue below still runs with zeroed accumulators.
    if (params.ell_ncol > 0) {
      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.k() * params.gemm_k_size
      };

      cutlass::MatrixCoord tb_offset_B{
        threadblock_tile_offset.k() * params.gemm_k_size,
        ell_block_offset_n * params.ell_blocksize
        + tile_offset_n * Mma::Shape::kN
      };

      // Start of this ELL block column's entries inside ell_idx: the index
      // array holds ell_ncol / ell_blocksize entries per ELL block.
      int ell_idx_start =
        ell_block_offset_n *
        (params.ell_ncol / params.ell_blocksize);
      const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]);

      // Problem size is a function of threadblock index in the K dimension;
      // it is further clamped to the ELL column count.
      int problem_size_k = min(
        params.problem_size.k(),
        (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
      problem_size_k = min(problem_size_k, params.ell_ncol);

      // Compute threadblock-scoped matrix multiply-add
      int gemm_k_iterations =
        (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        params.ref_A.data(),
        {params.problem_size.m(), problem_size_k},
        thread_idx,
        tb_offset_A);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        params.ref_B.data(),
        {problem_size_k, params.problem_size.n()},
        thread_idx,
        tb_offset_B);

      // Stride of the dense operand's (A's) iterator, handed to the ELL
      // iterator — presumably to turn ELL block indices into A offsets.
      int ell_stride = iterator_A.get_stride();

      typename cutlass::transform::threadblock::ell::Iterator ell_iterator(
        shared_storage.ell,
        ell_idx_ptr,
        params.ell_blocksize,
        params.ell_base_idx,
        Mma::Shape::kK,
        problem_size_k,
        ell_stride,
        thread_idx
      );

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      if (!kSplitKSerial || gemm_k_iterations > 0) {
        // check if index computations can be skipped
        constexpr int kAlignmentA = Mma::IteratorA::AccessType::kElements;
        constexpr int kAlignmentB = Mma::IteratorB::AccessType::kElements;
        constexpr int kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
        constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8);
        constexpr bool is_multiple_alignment =
          (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1);

        // ell_blocksize is a power of two and spans at least one K tile.
        const bool is_specialized_blocksize =
          ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0
          && params.ell_blocksize >= Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add. The first template
        // argument is `false` in this B-sparse kernel (the A-sparse kernel
        // passes `true`); the second selects the specialized path checked above.
        // `template` is required: `mma` has a dependent type.
        if ((is_double || is_multiple_alignment) && is_specialized_blocksize) {
          mma.template operator()<false, true>(
            gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
        else {
          mma.template operator()<false, false>(
            gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator);
        }
      }
    } // if (params.ell_ncol > 0)

    //
    // Epilogue
    //

    OutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    // Recomputed here instead of being kept live across the mainloop
    // (presumably to relieve register pressure).
    threadblock_tile_offset =
        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
    ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block;
    tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block;

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      ell_block_offset_n * params.ell_blocksize
      + tile_offset_n * Mma::Shape::kN
    );

    // Avoid out-of-bounds accesses: clamp columns to both the problem extent
    // and the end of the current ELL block column.
    MatrixCoord threadblock_extent(
      min(params.problem_size.m(),
          (threadblock_tile_offset.m()+1) * Mma::Shape::kM),
      min(params.problem_size.n(),
          ell_block_offset_n * params.ell_blocksize
          + min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize))
    );

    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

    // If performing a reduction via split-K, fetch the initial synchronization
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // Fetch the synchronization lock initially but do not block.
      semaphore.fetch();

      // Indicate which position in a serial reduction the output operator is currently updating
      output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
    }

    // Tile iterator loading from source tensor.
    typename Epilogue::OutputTileIterator iterator_C(
      params.params_C,
      params.ref_C.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    // Tile iterator writing to destination tensor.
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      params.ref_D.data(),
      threadblock_extent,
      thread_idx,
      threadblock_offset
    );

    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Wait on the semaphore - this latency may have been covered by iterator construction
    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
      if (threadblock_tile_offset.k()) {
        iterator_C = iterator_D;
      }

      semaphore.wait(threadblock_tile_offset.k());
    }

    // Execute the epilogue operator to update the destination tensor.
    epilogue(output_op, iterator_D, accumulators, iterator_C);

    //
    // Release the semaphore
    //

    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

      // The final K partition resets the semaphore to 0 for subsequent grids;
      // every other partition hands the lock to the next partition.
      int lock = (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
                   ? 0
                   : threadblock_tile_offset.k() + 1;

      semaphore.release(lock);
    }
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass