-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathcopy_traits_sm100_tma.hpp
487 lines (433 loc) · 20.6 KB
/
copy_traits_sm100_tma.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
/***************************************************************************************************
* Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif
#include <cute/tensor.hpp>
#include <cute/atom/copy_traits_sm90_tma.hpp>
#include <cute/arch/copy_sm100_tma.hpp>
#include <cute/atom/copy_traits.hpp>
namespace cute
{
//////////////////////////////////////////////////////////////////////////////
////////////////////////////// TMA_LOAD ////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM100_TMA_2SM_LOAD_OP : SM100_TMA_2SM_LOAD {};
// The non-executable SM100_TMA_2SM_LOAD with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_2>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM100_TMA_2SM_LOAD with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
with(
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
}
// Construct an executable SM100_TMA_2SM_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
with(
TmaDescriptor const* new_tma_desc,
uint64_t& tma_mbar,
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
}
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Don't try to execute a copy with SM100_TMA_2SM_LOAD before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
// The executable SM100_TMA_2SM_LOAD with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_OP, NumBitsPerTMA>
{
using ThrID = Layout<_2>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint64_t // cache hint
> const opargs_;
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
//////////////////////////////////////////////////////////////////////////////
struct SM100_TMA_2SM_LOAD_MULTICAST_OP : SM100_TMA_2SM_LOAD_MULTICAST {};
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
{
using ThrID = Layout<_2>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_MULTICAST_OP arguments
TmaDescriptor tma_desc_;
using AuxParams = AuxParams_;
AuxParams aux_params_;
// Return TmaDescriptor/TensorMap
CUTE_HOST_DEVICE constexpr
TmaDescriptor const*
get_tma_descriptor() const {
return &tma_desc_;
}
// Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
with(
uint64_t& tma_load_mbar,
uint16_t const& multicast_mask,
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
}
// Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
with(
TmaDescriptor const* new_tma_desc,
uint64_t& tma_load_mbar,
uint16_t const& multicast_mask,
TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const {
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
}
template <class GShape>
CUTE_HOST_DEVICE constexpr
auto
get_tma_tensor(GShape const& g_shape) const {
static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Don't try to execute a copy with SM100_TMA_2SM_LOAD_MULTICAST_OP before calling .with()
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst) = delete;
};
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_MULTICAST_OP, NumBitsPerTMA>
{
using ThrID = Layout<_2>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_2,NumBitsPerTMA>, Stride<NumBitsPerTMA,_1>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// SM100_TMA_2SM_LOAD_MULTICAST_OP arguments
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint16_t, // multicast mask
uint64_t // cache hint
> const opargs_;
};
////////////////////////////////////
// Make TMA
///////////////////////////////////
#if !defined(__CUDACC_RTC__)
/** Make a CuTe CTA-collective TiledCopy for a TMA operation.
*
* @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD
* @param gtensor The GMEM Tensor to be involved in the TMA.
* @param slayout The SMEM Layout to be involved in the TMA.
* @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with.
* This is often the cluster_tile_shape that is used to tile the GMEM:
* local_tile(gtensor, cluster_tile_shape, cluster_coord)
* -> Cluster-local tile of GMEM
* @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning.
*
* This code attempts to maximize the TMA box size. It does this by tracing
* the SMEM "vector" -- the inverse of the smem layout -- to find the largest
* contiguous array of smem that can be written to/from global memory given
* the constraints that the TMA instruction imposes.
*
* This is accomplished by assigning "basis" strides to the GMEM to track which
* modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according
* to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc.
*
* Examples:
*/
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class Cluster_Tiler,
class... Args>
CUTE_HOST
auto
make_tma_copy_A_sm100(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // (M, K, ...)
SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...)
Cluster_Tiler const& cluster_tiler, // (TILER_M, TILER_N, TILER_K, ...)
TiledMMA<Args...> const& mma)
{
// Keep only MK modes from MNK
auto cluster_tiler_mk = remove<1>(cluster_tiler);
// cluster tile coord -> gtensor coord
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mk); // (TILE_M, TILE_K, ...)
// cta val idx -> gmem mode
auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_K, ...)
auto cta_t_vmnk_strides = [](){
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast
} else
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
is_same_v<CopyOp, SM90_TMA_STORE> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
} else {
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
}
}();
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
// cta rank -> logical cta idx
auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)));
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
}
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class Cluster_Tiler,
class... Args>
CUTE_HOST
auto
make_tma_copy_B_sm100(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // (N, K, ...)
SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...)
Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...)
TiledMMA<Args...> const& mma)
{
// Keep only NK modes from MNK
auto cluster_tiler_nk = remove<0>(cluster_tiler);
// cluster tile coord -> gtensor coord
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_nk); // (TILE_N, TILE_K, ...)
// cta val idx -> gmem mode
auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_N, MMA_K, ...)
auto cta_t_vmnk_strides = [](){
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast
} else
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
is_same_v<CopyOp, SM90_TMA_STORE> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast
} else {
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
}
}();
auto cta_t_shape = shape(mma.get_thr_layout_vmnk());
// cta rank -> logical cta idx
auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)));
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
}
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class Cluster_Tiler,
class... Args>
CUTE_HOST
auto
make_tma_copy_C_sm100(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // (M, N, ...)
SLayout const& slayout, // (MMA, MMA_M, MMA_N, ...)
Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...)
TiledMMA<Args...> const& mma)
{
// Keep only MN modes from MNK
auto cluster_tiler_mn = remove<2>(cluster_tiler);
// cluster tile coord -> gtensor coord
auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mn); // (TILE_M, TILE_N, ...)
// cta val idx -> gmem mode
auto cta_v_tile = layout<1>(mma.thrfrg_C(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_N, ...)
static_assert(is_same_v<CopyOp, SM90_TMA_LOAD> ||
is_same_v<CopyOp, SM90_TMA_STORE> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>,
"Unsupported TMA Op, expected a non-multicast TMA");
// No multicast, so only 1 CTA involved
auto cta_t_map = Layout<_1,_0>{};
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_tiled<TmaType>(copy_op, gtensor, slayout, cta_t_map, cta_v_tile);
}
////////////////////////////////////
// Experimental Make TMA Atom
///////////////////////////////////
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class MMA_Tiler,
class... Args,
class ClusterShapeVMNK>
CUTE_HOST
auto
make_tma_atom_A_sm100(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // (M, K, ...)
SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...)
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
TiledMMA<Args...> const& mma,
ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K)
{
// Keep only MK modes from MNK
auto mma_tiler_mk = remove<1>(mma_tiler);
// cluster tile coord -> gtensor coord
auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_mk); // (TILE_M, TILE_K, ...)
// cta val idx -> gmem mode
auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_M, MMA_K, ...)
#if 0
print("(tma_a) slayout: "); print(slayout); print("\n");
print("(tma_a) mma_tiler_nk: "); print(mma_tiler_nk); print("\n");
print("(tma_a) g_tile: "); print(g_tile); print("\n");
print("(tma_a) mma_tiler: "); print(mma_tiler); print("\n");
print("(tma_a) cta_v_tile: "); print(cta_v_tile); print("\n");
#endif
// The size of the multicasting
auto num_multicast = [&](){
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast
} else
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
is_same_v<CopyOp, SM90_TMA_STORE> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
} else {
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
}
}();
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_atom<TmaType>(copy_op, gtensor, slayout, num_multicast, cta_v_tile);
}
template <class TmaInternalType = void,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class MMA_Tiler,
class... Args,
class ClusterShapeVMNK>
CUTE_HOST
auto
make_tma_atom_B_sm100(CopyOp const& copy_op,
Tensor<GEngine,GLayout> const& gtensor, // (N, K, ...)
SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...)
MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...)
TiledMMA<Args...> const& mma,
ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K)
{
// Keep only NK modes from MNK
auto mma_tiler_nk = remove<0>(mma_tiler);
// cluster tile coord -> gtensor coord
auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_nk); // (TILE_N, TILE_K, ...)
// cta val idx -> gmem mode
auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat<rank(g_tile)>(_)); // (MMA, MMA_N, MMA_K, ...)
#if 0
print("(tma_b) slayout: "); print(slayout); print("\n");
print("(tma_b) mma_tiler_nk: "); print(mma_tiler_nk); print("\n");
print("(tma_b) g_tile: "); print(g_tile); print("\n");
print("(tma_b) mma_tiler: "); print(mma_tiler); print("\n");
print("(tma_b) cta_v_tile: "); print(cta_v_tile); print("\n");
#endif
// The size of the multicasting
auto num_multicast = [&](){
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD_MULTICAST> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD_MULTICAST>) {
return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast
} else
if constexpr (is_same_v<CopyOp, SM90_TMA_LOAD> ||
is_same_v<CopyOp, SM90_TMA_STORE> ||
is_same_v<CopyOp, SM100_TMA_2SM_LOAD>) {
return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast
} else {
static_assert(dependent_false<CopyOp>, "Unsupported TMA");
}
}();
// Prefer TmaInternalType if specified. Fallback to GEngine::value_type
using TmaType = conditional_t<is_same<void, TmaInternalType>::value, typename GEngine::value_type, TmaInternalType>;
return detail::make_tma_copy_atom<TmaType>(copy_op, gtensor, slayout, num_multicast, cta_v_tile);
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute